Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 06:12:50 +08:00
use fixture
Former-commit-id: 80a9e6bf94cf14fa63e6b6cdf7e1ce13722c8b5e
This commit is contained in:
parent 8053929b20
commit 96b82ccd4d
@@ -163,7 +163,7 @@ class ModelArguments:
     )
     infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
         default="auto",
-        metadata={"help": "Data type for model weights and activations at inference."}
+        metadata={"help": "Data type for model weights and activations at inference."},
     )
     hf_hub_token: Optional[str] = field(
         default=None,
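The hunk above only adds a trailing comma after the metadata dict of the infer_dtype field; behavior is unchanged. For context, a minimal sketch (illustrative names, not the project's actual module) of how a dataclass field like this is typically consumed, assuming the arguments are parsed with transformers.HfArgumentParser, which turns each field into a CLI flag and reuses the metadata "help" string for --help output:

from dataclasses import dataclass, field
from typing import Literal

from transformers import HfArgumentParser


@dataclass
class SketchModelArguments:
    # Same shape as the field touched above; the trailing comma is purely stylistic.
    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
        default="auto",
        metadata={"help": "Data type for model weights and activations at inference."},
    )


# Each field becomes a flag, e.g. --infer_dtype float16.
parser = HfArgumentParser(SketchModelArguments)
(sketch_args,) = parser.parse_args_into_dataclasses(args=["--infer_dtype", "float16"])
assert sketch_args.infer_dtype == "float16"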
@@ -15,6 +15,7 @@
 import os
 from typing import Dict
 
+import pytest
 import torch
 from transformers import AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
@@ -43,11 +44,15 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
 
 
-def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
-    state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
-    self.v_head.load_state_dict(state_dict, strict=False)
-    del state_dict
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init
 
 
 def test_base():
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
@@ -60,8 +65,8 @@ def test_base():
     compare_model(model, ref_model)
 
 
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_valuehead():
-    AutoModelForCausalLMWithValueHead.post_init = post_init  # patch for CPU test
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)
     model = load_model(
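The two hunks above replace the module-level monkey patch of AutoModelForCausalLMWithValueHead.post_init with a pytest fixture, so the CPU-only loading workaround is applied only to tests that request it via @pytest.mark.usefixtures. A self-contained sketch of the same pattern, using a stand-in class rather than the real trl model:

import pytest


class FakeValueHeadModel:
    # Stand-in for AutoModelForCausalLMWithValueHead in this sketch.
    def post_init(self, state_dict):
        raise RuntimeError("pretend the stock post_init cannot run on a CPU-only host")


@pytest.fixture
def fix_valuehead_cpu_loading():
    # Same idea as the fixture in the diff: swap in a CPU-friendly post_init
    # for the tests that opt in.
    def post_init(self, state_dict):
        self.v_head_keys = [k for k in state_dict if k.startswith("v_head.")]

    FakeValueHeadModel.post_init = post_init


@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead_sketch():
    model = FakeValueHeadModel()
    model.post_init({"v_head.summary.weight": 0, "lm_head.weight": 1})
    assert model.v_head_keys == ["v_head.summary.weight"]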
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import os
-from typing import Sequence
+from typing import Dict, Sequence
 
+import pytest
 import torch
 from peft import LoraModel, PeftModel
 from transformers import AutoModelForCausalLM
@@ -71,6 +72,16 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
         assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
 
 
+@pytest.fixture
+def fix_valuehead_cpu_loading():
+    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
+        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+    AutoModelForCausalLMWithValueHead.post_init = post_init
+
+
 def test_lora_train_qv_modules():
     model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "q_proj,v_proj", **TRAIN_ARGS})
     tokenizer_module = load_tokenizer(model_args)
@@ -166,6 +177,7 @@ def test_lora_train_new_adapters():
     )
 
 
+@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
 def test_lora_train_valuehead():
     model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
     tokenizer_module = load_tokenizer(model_args)
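The fixture body added to the LoRA tests above is a verbatim copy of the one added to the base-model tests. Two follow-ups that are not part of this commit but fit the same pattern: pytest can share a single fixture across both test files through a conftest.py in the tests directory, and pytest's built-in monkeypatch fixture would restore the original post_init after each test (neither the old module-level patch nor the new fixture undoes it). A hedged sketch combining both ideas:

# conftest.py (hypothetical; not part of this commit)
from typing import Dict

import pytest
import torch
from trl import AutoModelForCausalLMWithValueHead


@pytest.fixture
def fix_valuehead_cpu_loading(monkeypatch: pytest.MonkeyPatch):
    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
        state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

    # monkeypatch.setattr reverts the attribute once the requesting test ends,
    # so other tests see the stock trl implementation again.
    monkeypatch.setattr(AutoModelForCausalLMWithValueHead, "post_init", post_init)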