update ppo trainer

This commit is contained in:
hiyouga
2023-11-20 21:39:15 +08:00
parent 48211e3799
commit 5021062493
7 changed files with 68 additions and 41 deletions

View File

@@ -16,7 +16,10 @@ try:
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
try:
_is_bf16_available = torch.cuda.is_bf16_supported()
except:
_is_bf16_available = False
if TYPE_CHECKING:
from transformers import HfArgumentParser