use fixture

Former-commit-id: 10761985691b9f934f7689c1f82aa6dd68febcca
hiyouga 2024-06-15 20:06:17 +08:00
parent 7f90b0cd20
commit 14f7bfc545
3 changed files with 24 additions and 7 deletions

View File

@@ -163,7 +163,7 @@ class ModelArguments:
     )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
-        metadata={"help": "Data type for model weights and activations at inference."}
+        metadata={"help": "Data type for model weights and activations at inference."},
     )
     hf_hub_token: Optional[str] = field(
         default=None,
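
The change itself is only the trailing comma after the metadata argument, matching the style of the neighboring fields. For context, a minimal sketch of how a Literal-typed hparams field like this is typically consumed; the ToyArguments class is invented here, and transformers' HfArgumentParser is assumed to be the consumer (it maps the Literal values to argparse choices and the metadata help string to the flag's help text):

from dataclasses import dataclass, field
from typing import Literal

from transformers import HfArgumentParser


@dataclass
class ToyArguments:
    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
        default="auto",
        metadata={"help": "Data type for model weights and activations at inference."},
    )


parser = HfArgumentParser(ToyArguments)
(toy_args,) = parser.parse_args_into_dataclasses(args=["--infer_dtype", "bfloat16"])
print(toy_args.infer_dtype)  # bfloat16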

View File

@@ -15,6 +15,7 @@
 import os
 from typing import Dict

+import pytest
 import torch
 from transformers import AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
@@ -43,11 +44,15 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True


-def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
-    state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
-    self.v_head.load_state_dict(state_dict, strict=False)
-    del state_dict
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init


 def test_base():
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
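
The patched post_init keeps only the value-head weights: it filters for keys carrying the `v_head.` prefix, strips those 7 characters, and loads the result into self.v_head with strict=False, which lets the test build the model on CPU. A standalone sketch of just the key transformation, with invented tensor names:

import torch

checkpoint = {
    "v_head.summary.weight": torch.zeros(1, 8),
    "v_head.summary.bias": torch.zeros(1),
    "lm_head.weight": torch.zeros(4, 4),  # no "v_head." prefix, filtered out
}
# k[7:] drops the 7-character "v_head." prefix, as in the patched post_init
v_head_state = {k[7:]: v for k, v in checkpoint.items() if k.startswith("v_head.")}
print(sorted(v_head_state))  # ['summary.bias', 'summary.weight']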
@@ -60,8 +65,8 @@ def test_base():
     compare_model(model, ref_model)


+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_valuehead():
-    AutoModelForCausalLMWithValueHead.post_init = post_init  # patch for CPU test
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(
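
This hunk is the point of the commit: the class-level patch moves out of the test body into the named fixture, and tests opt in declaratively with @pytest.mark.usefixtures. Note that the fixture assigns post_init directly, so the patch stays in effect for the rest of the session; a minimal self-contained sketch of the same pattern with invented names, using pytest's built-in monkeypatch fixture instead, which restores the original method once the test finishes:

import pytest


class Greeter:
    def greet(self) -> str:
        return "hello"


@pytest.fixture
def loud_greeting(monkeypatch: pytest.MonkeyPatch):
    # setattr is undone automatically during fixture teardown
    monkeypatch.setattr(Greeter, "greet", lambda self: "HELLO")


@pytest.mark.usefixtures("loud_greeting")
def test_greet_is_patched():
    assert Greeter().greet() == "HELLO"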

View File

@@ -13,8 +13,9 @@
 # limitations under the License.

 import os
-from typing import Sequence
+from typing import Dict, Sequence

+import pytest
 import torch
 from peft import LoraModel, PeftModel
 from transformers import AutoModelForCausalLM
@@ -71,6 +72,16 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True


+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init
+
+
 def test_lora_train_qv_modules():
     model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
     tokenizer_module = load_tokenizer(model_args)
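
The hunk header above truncates the compare_model signature at diff_k. Judging from the visible assert line and the newly imported Sequence, the helper presumably requires keys matching diff_keys to have changed while everything else stays equal; a plausible reconstruction, with the exact signature and semantics treated as assumptions:

from typing import Sequence

import torch


def compare_model_sketch(model_a: torch.nn.Module, model_b: torch.nn.Module, diff_keys: Sequence[str] = ()):
    # Assumed behavior: parameters whose names match diff_keys must differ
    # (e.g. LoRA-trained weights); all remaining parameters must be identical.
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()
    assert set(state_dict_a.keys()) == set(state_dict_b.keys())
    for name in state_dict_a.keys():
        if any(key in name for key in diff_keys):
            assert torch.allclose(state_dict_a[name], state_dict_b[name]) is False
        else:
            assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True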
@@ -166,6 +177,7 @@ def test_lora_train_new_adapters():
     )


+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_lora_train_valuehead():
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)
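
With this commit the fixture body is duplicated verbatim in both test files. A hypothetical follow-up, not part of this commit, would hoist it into a shared conftest.py, which pytest discovers automatically for every test module in the directory, letting both files drop their local copies:

# conftest.py (hypothetical location, alongside the two test files)
from typing import Dict

import pytest
from trl import AutoModelForCausalLMWithValueHead


@pytest.fixture
def fix_valuehead_cpu_loading():
    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

    AutoModelForCausalLMWithValueHead.post_init = post_init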