update ppo trainer

Former-commit-id: 5021062493
This commit is contained in:
hiyouga
2023-11-20 21:39:15 +08:00
parent d72f123851
commit f06c4c8f7a
7 changed files with 68 additions and 41 deletions

View File

@@ -16,7 +16,10 @@ try:
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
try:
_is_bf16_available = torch.cuda.is_bf16_supported()
except:
_is_bf16_available = False
if TYPE_CHECKING:
from transformers import HfArgumentParser