Former-commit-id: 23961bdf6fdbcde64e7b943f699fdeb4ac024043
This commit is contained in:
hiyouga 2024-08-20 00:10:52 +08:00
parent 388f0a6e05
commit d111a324bc
2 changed files with 4 additions and 4 deletions

View File

@ -137,12 +137,12 @@ def get_device_count() -> int:
r"""
Gets the number of available GPU or NPU devices.
"""
if is_torch_npu_available():
if is_torch_xpu_available():
return torch.xpu.device_count()
elif is_torch_npu_available():
return torch.npu.device_count()
elif is_torch_cuda_available():
return torch.cuda.device_count()
elif is_torch_xpu_available():
return torch.xpu.device_count()
else:
return 0

View File

@ -133,9 +133,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
ref_model=ref_model,
tokenizer=tokenizer,
dataset=train_dataset,
optimizer=optimizer,
data_collator=data_collator,
lr_scheduler=scheduler,
optimizer=optimizer,
)
self.args = training_args