From daebca2368f1f46ed5c2c86fb276457df932827a Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Tue, 20 Aug 2024 00:10:52 +0800 Subject: [PATCH] tiny fix Former-commit-id: c8b4c7fee5398654683b713ad5c03b5daf13218a --- src/llamafactory/extras/misc.py | 6 +++--- src/llamafactory/train/ppo/trainer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 8facad93..4001363c 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -137,12 +137,12 @@ def get_device_count() -> int: r""" Gets the number of available GPU or NPU devices. """ - if is_torch_npu_available(): + if is_torch_xpu_available(): + return torch.xpu.device_count() + elif is_torch_npu_available(): return torch.npu.device_count() elif is_torch_cuda_available(): return torch.cuda.device_count() - elif is_torch_xpu_available(): - return torch.xpu.device_count() else: return 0 diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 3b792061..eeb690ca 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -133,9 +133,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ref_model=ref_model, tokenizer=tokenizer, dataset=train_dataset, + optimizer=optimizer, data_collator=data_collator, lr_scheduler=scheduler, - optimizer=optimizer, ) self.args = training_args