mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-14 10:56:56 +08:00
add unittest
Former-commit-id: 8a1f0c5f922989e08a19c65de0b2c4afd2a5771f
This commit is contained in:
@@ -16,8 +16,7 @@ import os
|
||||
|
||||
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
||||
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import load_infer_model
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -42,9 +41,7 @@ def test_attention():
|
||||
"fa2": "LlamaFlashAttention2",
|
||||
}
|
||||
for requested_attention in attention_available:
|
||||
model_args, _, finetuning_args, _ = get_infer_args({"flash_attn": requested_attention, **INFER_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args)
|
||||
model = load_infer_model(flash_attn=requested_attention, **INFER_ARGS)
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
assert module.__class__.__name__ == llama_attention_classes[requested_attention]
|
||||
|
||||
@@ -17,8 +17,7 @@ import os
|
||||
import torch
|
||||
|
||||
from llamafactory.extras.misc import get_current_device
|
||||
from llamafactory.hparams import get_train_args
|
||||
from llamafactory.model import load_model, load_tokenizer
|
||||
from llamafactory.train.test_utils import load_train_model
|
||||
|
||||
|
||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||
@@ -41,34 +40,26 @@ TRAIN_ARGS = {
|
||||
|
||||
|
||||
def test_checkpointing_enable():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": False, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
|
||||
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
|
||||
assert getattr(module, "gradient_checkpointing") is True
|
||||
|
||||
|
||||
def test_checkpointing_disable():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"disable_gradient_checkpointing": True, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
|
||||
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
|
||||
assert getattr(module, "gradient_checkpointing") is False
|
||||
|
||||
|
||||
def test_upcast_layernorm():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"upcast_layernorm": True, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and "norm" in name:
|
||||
assert param.dtype == torch.float32
|
||||
|
||||
|
||||
def test_upcast_lmhead_output():
|
||||
model_args, _, _, finetuning_args, _ = get_train_args({"upcast_lmhead_output": True, **TRAIN_ARGS})
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
|
||||
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
|
||||
outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
|
||||
assert outputs.dtype == torch.float32
|
||||
|
||||
Reference in New Issue
Block a user