support full-parameter PPO

This commit is contained in:
hiyouga
2023-11-16 02:08:04 +08:00
parent 8350bcf85d
commit ce78303600
20 changed files with 288 additions and 145 deletions

View File

@@ -64,6 +64,16 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def get_current_device() -> str:
    r"""
    Gets the device identifier assigned to the current process.

    Returns:
        "xpu:{index}" if accelerate reports that an XPU is available,
        "cuda:{index}" if CUDA is available, otherwise "cpu". The index is the
        local process index reported by a freshly created Accelerator.
    """
    # Imported lazily so that merely importing this module does not pull in accelerate.
    import accelerate
    from accelerate import Accelerator

    # The accelerator is created only to read the local process index.
    # It is instantiated unconditionally (even on the CPU path), as before.
    local_index = Accelerator().local_process_index

    if accelerate.utils.is_xpu_available():
        return "xpu:{}".format(local_index)

    if torch.cuda.is_available():
        # Return a device string rather than the bare integer index so the
        # result always matches the declared ``str`` return type.
        return "cuda:{}".format(local_index)

    return "cpu"
def get_logits_processor() -> "LogitsProcessorList":
r"""
Gets logits processor that removes NaN and Inf logits.