diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 8facad93..4001363c 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -137,12 +137,12 @@ def get_device_count() -> int:
     r"""
     Gets the number of available GPU or NPU devices.
     """
-    if is_torch_npu_available():
+    if is_torch_xpu_available():
+        return torch.xpu.device_count()
+    elif is_torch_npu_available():
         return torch.npu.device_count()
     elif is_torch_cuda_available():
         return torch.cuda.device_count()
-    elif is_torch_xpu_available():
-        return torch.xpu.device_count()
     else:
         return 0
 
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 3b792061..eeb690ca 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -133,9 +133,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             ref_model=ref_model,
             tokenizer=tokenizer,
             dataset=train_dataset,
+            optimizer=optimizer,
             data_collator=data_collator,
             lr_scheduler=scheduler,
-            optimizer=optimizer,
         )
 
         self.args = training_args