Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
	fix saving custom code
Former-commit-id: 3f8f40bffd4f61fcc045f5f8a07420f3b46d0f7a
This commit is contained in:
parent c61de6f669
commit e9736b2ba0
@@ -11,7 +11,7 @@ from transformers import (
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 
 from llmtuner.extras.logging import get_logger
@@ -36,7 +36,7 @@ def load_model_and_tokenizer(
     finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
     r"""
     Loads pretrained model and tokenizer.
 
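Note (not part of the diff): the return annotation above is widened from PreTrainedTokenizer to PreTrainedTokenizerBase. In transformers, both the slow and the fast tokenizer implementations derive from PreTrainedTokenizerBase, while fast tokenizers are not subclasses of PreTrainedTokenizer, so the base class is the safer annotation whether or not that was the motivation here. A minimal check:

    # Sketch for illustration only; verifies the relevant class hierarchy in transformers.
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast

    assert issubclass(PreTrainedTokenizer, PreTrainedTokenizerBase)      # slow tokenizers
    assert issubclass(PreTrainedTokenizerFast, PreTrainedTokenizerBase)  # fast tokenizers
    assert not issubclass(PreTrainedTokenizerFast, PreTrainedTokenizer)  # fast is not a slow subclass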
@@ -113,12 +113,12 @@ def load_model_and_tokenizer(
     )
 
     # Register auto class to save the custom code files.
-    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
+    if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
-        tokenizer.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
+    if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
         model.__class__.register_for_auto_class()
+    if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
+        tokenizer.__class__.register_for_auto_class()
 
     # Initialize adapters
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
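Note (not part of the diff): register_for_auto_class marks a remote-code class so that save_pretrained records it in auto_map and copies the custom *.py files into the output directory, which is what lets a fine-tuned trust_remote_code model be re-loaded from the saved checkpoint. The hunk above also reads the tokenizer's auto_map from tokenizer.init_kwargs (where entries from tokenizer_config.json end up) instead of from the model config. A hedged sketch of the fixed checks, with illustrative names:

    # Sketch only; `config`, `model` and `tokenizer` are assumed to come from
    # AutoConfig / AutoModelForCausalLM / AutoTokenizer with trust_remote_code=True.
    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase

    def register_custom_code(config, model, tokenizer) -> None:
        # Remote-code checkpoints advertise their custom classes through `auto_map`.
        if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
            config.__class__.register_for_auto_class()
        if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
            model.__class__.register_for_auto_class()
        # The tokenizer's auto_map is stored in its init_kwargs rather than on the config.
        if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
            tokenizer.__class__.register_for_auto_class()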
@@ -300,6 +300,45 @@ class BaichuanPreTrainedModel(PreTrainedModel):
         if isinstance(module, BaichuanModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_baichuan_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key:  [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 class BaichuanModel(BaichuanPreTrainedModel):
 
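Note (not part of the diff): the two helpers above only reshape between the fused layout [batch_size * num_heads, ...] used by this Baichuan implementation and the standard layout [batch_size, num_heads, ...] expected by generic generation utilities. A small round-trip sketch with made-up shapes:

    # Sketch only; demonstrates the view() round trip performed by the converters above.
    import torch

    batch, heads, head_dim, seq = 2, 4, 8, 5
    key = torch.randn(batch * heads, head_dim, seq)    # Baichuan key layout
    value = torch.randn(batch * heads, seq, head_dim)  # Baichuan value layout

    std_key = key.view(batch, heads, head_dim, seq)      # standard key layout
    std_value = value.view(batch, heads, seq, head_dim)  # standard value layout

    assert torch.equal(std_key.view(batch * heads, head_dim, seq), key)
    assert torch.equal(std_value.view(batch * heads, seq, head_dim), value)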
@@ -318,9 +357,9 @@ class BaichuanModel(BaichuanPreTrainedModel):
 
     def get_input_embeddings(self):
         return self.embed_tokens
-        
+
     def set_input_embeddings(self, value):
-        self.embed_tokens = value  
+        self.embed_tokens = value
 
     def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
         return build_alibi_tensor(attention_mask, num_heads, dtype)
@@ -468,7 +507,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-    
+
 
 class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
@@ -498,7 +537,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
     def get_decoder(self):
         return self.model
-        
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -528,7 +567,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-        )   
+        )
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
@@ -559,11 +598,20 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):  
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
         if past_key_values:
             input_ids = input_ids[:, -1:]
 
+            # the cache may be in the standard format (e.g. in contrastive search)
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_baichuan_cache(past_key_values)
+
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
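Note (not part of the diff): once a key/value cache exists, only the newest token has to be fed to the model, and the leading dimension of the cached tensors tells the two layouts apart: it equals the batch size in the standard layout (which contrastive search produces) and batch_size * num_heads in the Baichuan layout, hence the conversion above. A tiny sketch of the last-token slicing:

    # Sketch only; illustrates why input_ids is cut down to its final column once a cache is present.
    import torch

    input_ids = torch.tensor([[101, 7, 42, 13]])  # shape [batch=1, seq=4]
    cache_present = True
    if cache_present:
        input_ids = input_ids[:, -1:]             # keep only the newest token -> tensor([[13]])
    assert input_ids.shape == (1, 1)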
@@ -571,21 +619,38 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             model_inputs = {"input_ids": input_ids}
 
         model_inputs.update(
-            {   
+            {
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
-            }   
-        )   
+            }
+        )
         return model_inputs
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
+        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
+        )
+        return self._convert_to_baichuan_cache(reordered_past)
 
     def quantize(self, bits: int):
         try:
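Note (not part of the diff): during beam search, beam_idx maps each surviving beam to the old beam it continues, so the batch dimension of every cached tensor has to be permuted with index_select before the next step; the method above does this in the standard layout and then converts back. A minimal sketch of the reordering itself:

    # Sketch only; shows what index_select with beam_idx does to one cached key tensor.
    import torch

    num_beams, heads, head_dim, seq = 3, 2, 4, 5
    layer_key = torch.randn(num_beams, heads, head_dim, seq)  # standard cache layout

    beam_idx = torch.tensor([2, 0, 0])  # beam 0 continues old beam 2, beams 1 and 2 continue old beam 0
    reordered = layer_key.index_select(0, beam_idx)

    assert torch.equal(reordered[0], layer_key[2])
    assert torch.equal(reordered[1], layer_key[0])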
@@ -594,7 +659,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
             raise ImportError(
                 f"Needs QLinear to run quantize."
             )
-        
+
         for layer in self.model.layers:
             layer.self_attn.W_pack = QLinear(
                 bits=bits,
@@ -621,7 +686,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
                 weight=layer.mlp.up_proj.weight,
                 bias = None,
             )
-        return self 
+        return self
 
     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens