refactor model_dtype, fix PPO trainer

hiyouga committed 2023-10-11 23:16:01 +08:00
parent 5310e4d182
commit 2818af0b09
10 changed files with 104 additions and 119 deletions


@@ -138,11 +138,11 @@ class LlamaFlashAttention2(LlamaAttention):
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
             logger.warning_once("The input hidden states seems to be silently casted in float32.")
-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(self.config.torch_dtype)
+            key_states = key_states.to(self.config.torch_dtype)
+            value_states = value_states.to(self.config.torch_dtype)
 
-        if getattr(self, "num_key_value_groups"):
+        if getattr(self, "num_key_value_groups", None):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
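
Note on the dtype change (an annotation, not part of the commit): hard-coding torch.float16 silently downcasts models that were loaded in bfloat16, while self.config.torch_dtype casts the upcasted hidden states back to whatever dtype the model was loaded with. A minimal sketch, assuming a model loaded with torch_dtype=torch.bfloat16; the tensor shape and the model_dtype name are illustrative:

    import torch

    # Hypothetical stand-in for self.config.torch_dtype on a bf16 model.
    model_dtype = torch.bfloat16

    # Hidden states upcast to float32, e.g. by a layer norm kept in full precision.
    query_states = torch.randn(1, 8, 64, dtype=torch.float32)

    print(query_states.to(torch.float16).dtype)  # old behavior: torch.float16, wrong for bf16 runs
    print(query_states.to(model_dtype).dtype)    # new behavior: torch.bfloat16, matches the loaded model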
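
Likewise for the getattr change: the two-argument form getattr(obj, name) raises AttributeError when the attribute is missing, so the old guard could crash on attention modules that never define num_key_value_groups; the three-argument form returns the default (None) and simply skips the repeat_kv branch. A sketch with a hypothetical Attn class:

    class Attn:
        pass  # no num_key_value_groups attribute, as in non-GQA attention modules

    attn = Attn()

    try:
        if getattr(attn, "num_key_value_groups"):  # old guard
            pass
    except AttributeError as err:
        print("old guard raised:", err)

    if getattr(attn, "num_key_value_groups", None):  # new guard: missing attribute yields None
        pass  # only reached when the attribute exists and is truthy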