mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-27 09:10:35 +08:00
[misc] fix accelerator (#9661)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -12,57 +12,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from llamafactory.v1.accelerator.helper import get_current_accelerator
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels, apply_kernel
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
|
||||
|
||||
|
||||
class TestKernelPlugin(unittest.TestCase):
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
def test_apply_kernel(self, mock_get_accelerator):
|
||||
get_current_accelerator.cache_clear()
|
||||
mock_device = MagicMock()
|
||||
mock_device.type = "npu"
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.rope import npu_rope
|
||||
|
||||
apply_kernel(model, npu_rope.NpuRoPEKernel)
|
||||
|
||||
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
|
||||
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
|
||||
|
||||
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
|
||||
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_accelerator_cache():
|
||||
get_current_accelerator.cache_clear()
|
||||
|
||||
|
||||
class Test_Use_V1_Kernels(unittest.TestCase):
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
def test_use_v1_kernels(self, mock_get_accelerator):
|
||||
get_current_accelerator.cache_clear()
|
||||
mock_device = MagicMock()
|
||||
mock_device.type = "npu"
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
def test_apply_kernel(mock_get_accelerator: MagicMock):
|
||||
mock_device = MagicMock()
|
||||
setattr(mock_device, "type", "npu")
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_available_kernels
|
||||
apply_kernel(model, npu_rope.NpuRoPEKernel)
|
||||
|
||||
model = apply_available_kernels(model)
|
||||
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
|
||||
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
|
||||
|
||||
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
|
||||
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
|
||||
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)
|
||||
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
|
||||
|
||||
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
def test_apply_all_kernels(mock_get_accelerator: MagicMock):
|
||||
get_current_accelerator.cache_clear()
|
||||
mock_device = MagicMock()
|
||||
setattr(mock_device, "type", "npu")
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
|
||||
model = apply_available_kernels(model)
|
||||
|
||||
assert model.model.layers[0].input_layernorm is not original_rmsnorm_forward
|
||||
assert model.model.layers[0].mlp.forward is not original_swiglu_forward
|
||||
|
||||
Reference in New Issue
Block a user