mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 09:52:14 +08:00
	tiny fix
Former-commit-id: b5e9711ef375cc323fc083e742cccfc974550416
This commit is contained in:
parent f330b73682
commit 2723438531
@@ -182,11 +182,9 @@ def llama_flash_attention_2_forward(
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
-    else:
-        groupsz = q_len
 
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift back
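For context, a minimal sketch (not code from the repository) of why passing query_states.size(1) is safe in both branches. It assumes the patched shift() helper reshapes hidden states from (bsz, q_len, n_heads, head_dim) into (bsz * num_groups, groupsz, n_heads, head_dim), in the style of LongLoRA shift short attention; the tensor sizes and group_size_ratio value below are illustrative, not taken from the repository. After shifting, query_states.size(1) already equals groupsz, and in the non-shifted path it equals q_len, which is exactly the value the removed else branch used to supply.

import torch

# Illustrative sizes; assumed for this sketch only.
bsz, q_len, n_heads, head_dim = 2, 8, 4, 16
group_size_ratio = 0.25
groupsz = int(q_len * group_size_ratio)   # 2
num_groups = q_len // groupsz             # 4

states = torch.randn(bsz, q_len, n_heads, head_dim)

# Shifted (training) path: roll half the heads by groupsz // 2 along the
# sequence, then fold the sequence into num_groups blocks on the batch axis.
shifted = torch.cat(
    (
        states[:, :, : n_heads // 2],
        states[:, :, n_heads // 2 :].roll(-groupsz // 2, dims=1),
    ),
    dim=2,
).reshape(bsz * num_groups, groupsz, n_heads, head_dim)
assert shifted.size(1) == groupsz   # size(1) is the group size after shifting

# Plain path: no shift, so size(1) is simply the full query length.
assert states.size(1) == q_len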