mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
tiny fix
Former-commit-id: c8b4c7fee5398654683b713ad5c03b5daf13218a
This commit is contained in:
parent
a7604a95c1
commit
daebca2368
@ -137,12 +137,12 @@ def get_device_count() -> int:
|
||||
r"""
|
||||
Gets the number of available GPU or NPU devices.
|
||||
"""
|
||||
if is_torch_npu_available():
|
||||
if is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
elif is_torch_npu_available():
|
||||
return torch.npu.device_count()
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.device_count()
|
||||
elif is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
@ -133,9 +133,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
ref_model=ref_model,
|
||||
tokenizer=tokenizer,
|
||||
dataset=train_dataset,
|
||||
optimizer=optimizer,
|
||||
data_collator=data_collator,
|
||||
lr_scheduler=scheduler,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
|
||||
self.args = training_args
|
||||
|
Loading…
x
Reference in New Issue
Block a user