From d111a324bce0bf860e493371bc6042b2673edc14 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Tue, 20 Aug 2024 00:10:52 +0800 Subject: [PATCH] tiny fix Former-commit-id: 23961bdf6fdbcde64e7b943f699fdeb4ac024043 --- src/llamafactory/extras/misc.py | 6 +++--- src/llamafactory/train/ppo/trainer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 8facad93..4001363c 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -137,12 +137,12 @@ def get_device_count() -> int: r""" Gets the number of available GPU or NPU devices. """ - if is_torch_npu_available(): + if is_torch_xpu_available(): + return torch.xpu.device_count() + elif is_torch_npu_available(): return torch.npu.device_count() elif is_torch_cuda_available(): return torch.cuda.device_count() - elif is_torch_xpu_available(): - return torch.xpu.device_count() else: return 0 diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 3b792061..eeb690ca 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -133,9 +133,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ref_model=ref_model, tokenizer=tokenizer, dataset=train_dataset, + optimizer=optimizer, data_collator=data_collator, lr_scheduler=scheduler, - optimizer=optimizer, ) self.args = training_args