add tests

hiyouga
2024-06-15 19:51:20 +08:00
parent 572d8bbfdd
commit 1b834f50be
8 changed files with 166 additions and 14 deletions

@@ -18,7 +18,9 @@ from typing import Sequence
import torch
from peft import LoraModel, PeftModel
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from llamafactory.extras.misc import get_current_device
from llamafactory.hparams import get_infer_args, get_train_args
from llamafactory.model import load_model, load_tokenizer
@@ -27,6 +29,8 @@ TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
@@ -67,10 +71,29 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
            assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True


def test_lora_train_qv_modules():
    model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
    linear_modules = set()
    for name, param in model.named_parameters():
        if any(module in name for module in ["lora_A", "lora_B"]):
            linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
            assert param.requires_grad is True
            assert param.dtype == torch.float32
        else:
            assert param.requires_grad is False
            assert param.dtype == torch.float16
    assert linear_modules == {"q_proj", "v_proj"}
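As an aside (not part of the diff): the module-name parsing above relies on PEFT's parameter naming scheme, where injected matrices show up under names ending in "<target>.lora_A.<adapter>.weight". A minimal sketch, with a typical name assumed for illustration:

# Sketch only; the exact "base_model.model..." prefix depends on the wrapped model.
name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
module = name.split(".lora_", maxsplit=1)[0].split(".")[-1]
assert module == "q_proj"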


def test_lora_train_all_modules():
    model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS})
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
    linear_modules = set()
    for name, param in model.named_parameters():
        if any(module in name for module in ["lora_A", "lora_B"]):
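Not part of the diff: with lora_target set to "all", every linear projection in the decoder is expected to receive an adapter, so for a Llama-style block the collected names would be q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj and down_proj. A rough sketch of how such a target set can be derived (the helper name and the lm_head exclusion are assumptions, not taken from the commit):

import torch

def find_all_linear_names(model: "torch.nn.Module") -> set:
    # Collect the terminal names of every nn.Linear, skipping the output layer.
    names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
            names.add(name.split(".")[-1])
    return names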
@@ -90,6 +113,7 @@ def test_lora_train_extra_modules():
    )
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
    extra_modules = set()
    for name, param in model.named_parameters():
        if any(module in name for module in ["lora_A", "lora_B"]):
@@ -113,7 +137,9 @@ def test_lora_train_old_adapters():
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-    base_model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA, torch_dtype=model.dtype, device_map=model.device)
+    base_model = AutoModelForCausalLM.from_pretrained(
+        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
+    )
    ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
    for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
        param.data = param.data.to(torch.float32)
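Also not part of the diff: the float32 upcast above reproduces the dtype layout the tests assert elsewhere, namely trainable adapter weights in float32 and frozen base weights in float16. A tiny helper expressing that invariant (purely illustrative; load_model's actual casting logic is assumed, not shown here):

import torch

def check_dtypes(model: "torch.nn.Module") -> None:
    # Trainable (LoRA) parameters should be fp32, frozen base parameters fp16.
    for _, param in model.named_parameters():
        expected = torch.float32 if param.requires_grad else torch.float16
        assert param.dtype == expected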
@@ -128,7 +154,9 @@ def test_lora_train_new_adapters():
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
-    base_model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA, torch_dtype=model.dtype, device_map=model.device)
+    base_model = AutoModelForCausalLM.from_pretrained(
+        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
+    )
    ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
    for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
        param.data = param.data.to(torch.float32)
@@ -138,17 +166,31 @@ def test_lora_train_new_adapters():
    )


def test_lora_train_valuehead():
    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(
        tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True, add_valuehead=True
    )
    ref_model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
        TINY_LLAMA_VALUEHEAD, torch_dtype=torch.float16, device_map=get_current_device()
    )
    state_dict = model.state_dict()
    ref_state_dict = ref_model.state_dict()
    assert torch.allclose(state_dict["v_head.summary.weight"], ref_state_dict["v_head.summary.weight"])
    assert torch.allclose(state_dict["v_head.summary.bias"], ref_state_dict["v_head.summary.bias"])


def test_lora_inference():
    model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
    tokenizer_module = load_tokenizer(model_args)
    model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
-    base_model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA, torch_dtype=model.dtype, device_map=model.device)
+    base_model = AutoModelForCausalLM.from_pretrained(
+        TINY_LLAMA, torch_dtype=torch.float16, device_map=get_current_device()
+    )
    ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER)
    ref_model = ref_model.merge_and_unload()
    compare_model(model, ref_model)
    for name, param in model.named_parameters():
        assert param.requires_grad is False
        assert param.dtype == torch.float16
        assert "lora" not in name