Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 23:58:11 +08:00)
tiny fix
Former-commit-id: 23961bdf6fdbcde64e7b943f699fdeb4ac024043
This commit is contained in:
parent
388f0a6e05
commit
d111a324bc
@@ -137,12 +137,12 @@ def get_device_count() -> int:
     r"""
     Gets the number of available GPU or NPU devices.
     """
-    if is_torch_npu_available():
+    if is_torch_xpu_available():
+        return torch.xpu.device_count()
+    elif is_torch_npu_available():
         return torch.npu.device_count()
     elif is_torch_cuda_available():
         return torch.cuda.device_count()
-    elif is_torch_xpu_available():
-        return torch.xpu.device_count()
     else:
         return 0
 
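For reference, a minimal sketch of get_device_count() as it reads after this change. The imports are an assumption (these availability helpers are normally exposed by transformers.utils, and torch.xpu / torch.npu only exist when the matching backend support is installed); the functional effect of the commit is simply that XPU is probed before NPU and CUDA.

# Hedged sketch of the patched function; the imports are assumed, not shown in the hunk.
import torch
from transformers.utils import (
    is_torch_cuda_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)


def get_device_count() -> int:
    r"""
    Gets the number of available GPU or NPU devices.
    """
    if is_torch_xpu_available():  # Intel XPU first (needs an XPU-enabled PyTorch build)
        return torch.xpu.device_count()
    elif is_torch_npu_available():  # Ascend NPU (needs torch_npu)
        return torch.npu.device_count()
    elif is_torch_cuda_available():  # NVIDIA GPUs
        return torch.cuda.device_count()
    else:
        return 0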
@@ -133,9 +133,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             ref_model=ref_model,
             tokenizer=tokenizer,
             dataset=train_dataset,
+            optimizer=optimizer,
             data_collator=data_collator,
             lr_scheduler=scheduler,
-            optimizer=optimizer,
         )
 
         self.args = training_args
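The second hunk is a cosmetic reorder: every argument is passed by keyword, so moving optimizer= ahead of data_collator= only lines the call up with the parameter order of trl's PPOTrainer.__init__ of that era and does not change behavior. A hedged sketch of the call site inside CustomPPOTrainer.__init__ follows; the config= and model= lines are assumptions and are not part of the hunk.

# Fragment from CustomPPOTrainer.__init__ (not runnable on its own);
# the config=/model= names are assumed, the remaining keywords come from the hunk.
PPOTrainer.__init__(
    self,
    config=ppo_config,            # assumption: not visible in this hunk
    model=model,                  # assumption: not visible in this hunk
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=train_dataset,
    optimizer=optimizer,          # moved up to match PPOTrainer's parameter order
    data_collator=data_collator,
    lr_scheduler=scheduler,
)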