use fixture

Former-commit-id: 80a9e6bf94cf14fa63e6b6cdf7e1ce13722c8b5e
hiyouga 2024-06-15 20:06:17 +08:00
parent 8053929b20
commit 96b82ccd4d
3 changed files with 24 additions and 7 deletions
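The change below replaces a module-level monkey-patch of AutoModelForCausalLMWithValueHead.post_init with a pytest fixture that each test opts into via @pytest.mark.usefixtures. As a rough, self-contained sketch of that pattern (the Target class and patch_greet fixture here are made-up illustrations, not code from this repository):

import pytest


class Target:
    """Stand-in for the class being monkey-patched (illustrative only)."""

    def greet(self) -> str:
        return "original"


@pytest.fixture
def patch_greet():
    # Mirrors the commit's fixture: patch the class attribute only when a
    # test requests the fixture; there is no teardown, so the patch persists.
    Target.greet = lambda self: "patched"


@pytest.mark.usefixtures("patch_greet")
def test_patched():
    assert Target().greet() == "patched"

The practical difference is that the patch no longer runs unconditionally at import time; tests that do not request the fixture see the class untouched (at least until a patched test has run, since the fixture does not restore the original method).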

View File

@@ -163,7 +163,7 @@ class ModelArguments:
     )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
-        metadata={"help": "Data type for model weights and activations at inference."}
+        metadata={"help": "Data type for model weights and activations at inference."},
     )
     hf_hub_token: Optional[str] = field(
         default=None,

View File

@@ -15,6 +15,7 @@
 import os
 from typing import Dict
 
+import pytest
 import torch
 from transformers import AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
@@ -43,11 +44,15 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
 
 
-def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
-    state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
-    self.v_head.load_state_dict(state_dict, strict=False)
-    del state_dict
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init
 
 
 def test_base():
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
@@ -60,8 +65,8 @@ def test_base():
     compare_model(model, ref_model)
 
 
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_valuehead():
-    AutoModelForCausalLMWithValueHead.post_init = post_init  # patch for CPU test
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(
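One note on the pattern above: because test_valuehead only needs the fixture's side effect and never reads a value from it, requesting it with @pytest.mark.usefixtures is the idiomatic choice. The equivalent alternative, should a test ever need the fixture's return value, would be to accept it as a parameter:

def test_valuehead(fix_valuehead_cpu_loading):
    ...

Both forms trigger the same setup; usefixtures simply keeps the test signature free of an unused argument.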

View File

@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import os
-from typing import Sequence
+from typing import Dict, Sequence
 
+import pytest
 import torch
 from peft import LoraModel, PeftModel
 from transformers import AutoModelForCausalLM
@@ -71,6 +72,16 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
 
 
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init
+
+
 def test_lora_train_qv_modules():
     model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
     tokenizer_module = load_tokenizer(model_args)
@@ -166,6 +177,7 @@ def test_lora_train_new_adapters():
     )
 
 
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_lora_train_valuehead():
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)