Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-15 11:20:35 +08:00
refactor model_dtype, fix PPO trainer
@@ -138,11 +138,11 @@ class LlamaFlashAttention2(LlamaAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
             logger.warning_once("The input hidden states seems to be silently casted in float32.")
-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(self.config.torch_dtype)
+            key_states = key_states.to(self.config.torch_dtype)
+            value_states = value_states.to(self.config.torch_dtype)
 
-        if getattr(self, "num_key_value_groups"):
+        if getattr(self, "num_key_value_groups", None):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
 
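The first half of the hunk replaces the hardcoded torch.float16 cast with self.config.torch_dtype, so a model loaded in bfloat16 is cast back to bf16 (not fp16) after its hidden states were silently upcast to float32. Below is a minimal sketch of that pattern; DummyConfig and cast_back_to_model_dtype are hypothetical names used only for illustration and are not part of the repo.

import torch
from dataclasses import dataclass

@dataclass
class DummyConfig:
    # Stand-in for a config object exposing a torch_dtype field (assumption).
    torch_dtype: torch.dtype = torch.bfloat16

def cast_back_to_model_dtype(states: torch.Tensor, config: DummyConfig) -> torch.Tensor:
    # Flash-attention kernels expect fp16/bf16 inputs, but upstream layers can
    # silently upcast activations to float32. Restoring config.torch_dtype
    # (instead of a hardcoded torch.float16) keeps bfloat16 models in bf16.
    if states.dtype == torch.float32:
        return states.to(config.torch_dtype)
    return states

# usage
cfg = DummyConfig()
q = torch.randn(2, 8, 16, 64, dtype=torch.float32)
q = cast_back_to_model_dtype(q, cfg)
assert q.dtype == torch.bfloat16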
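The second change gives getattr a None default, so the guard simply skips the block on attention modules that do not define num_key_value_groups instead of raising AttributeError. When the attribute is present, key/value heads are tiled to match the query head count for grouped-query attention. The sketch below reimplements a repeat_kv-style helper purely for illustration (the real helper lives in the upstream Llama modeling code); AttnStub is a hypothetical stand-in for an attention module.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Tile each key/value head n_rep times so KV heads line up with query heads:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim).
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

class AttnStub:
    # Hypothetical attention module; num_key_value_groups may be absent on other modules.
    num_key_value_groups = 4

attn = AttnStub()
key_states = torch.randn(1, 2, 16, 64)
# The None default makes the check safe on modules lacking the attribute.
if getattr(attn, "num_key_value_groups", None):
    key_states = repeat_kv(key_states, attn.num_key_value_groups)
assert key_states.shape[1] == 8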