mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-24 23:02:49 +08:00
add test cases
Former-commit-id: b27269bd2b52fb9d43cde8a8b7f293099b0127a2
This commit is contained in:
parent
d4ce280fbc
commit
a3f4925c2c
@ -52,7 +52,7 @@ class VllmEngine(BaseEngine):
|
|||||||
"model": model_args.model_name_or_path,
|
"model": model_args.model_name_or_path,
|
||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"download_dir": model_args.cache_dir,
|
"download_dir": model_args.cache_dir,
|
||||||
"dtype": model_args.vllm_dtype,
|
"dtype": model_args.infer_dtype,
|
||||||
"max_model_len": model_args.vllm_maxlen,
|
"max_model_len": model_args.vllm_maxlen,
|
||||||
"tensor_parallel_size": get_device_count() or 1,
|
"tensor_parallel_size": get_device_count() or 1,
|
||||||
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
"gpu_memory_utilization": model_args.vllm_gpu_util,
|
||||||
|
@ -136,10 +136,6 @@ class ModelArguments:
|
|||||||
default=8,
|
default=8,
|
||||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||||
)
|
)
|
||||||
vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
|
|
||||||
default="auto",
|
|
||||||
metadata={"help": "Data type for model weights and activations in the vLLM engine."},
|
|
||||||
)
|
|
||||||
offload_folder: str = field(
|
offload_folder: str = field(
|
||||||
default="offload",
|
default="offload",
|
||||||
metadata={"help": "Path to offload model weights."},
|
metadata={"help": "Path to offload model weights."},
|
||||||
@ -148,6 +144,10 @@ class ModelArguments:
|
|||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether or not to use KV cache in generation."},
|
metadata={"help": "Whether or not to use KV cache in generation."},
|
||||||
)
|
)
|
||||||
|
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
|
||||||
|
default="auto",
|
||||||
|
metadata={"help": "Data type for model weights and activations at inference."}
|
||||||
|
)
|
||||||
hf_hub_token: Optional[str] = field(
|
hf_hub_token: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||||
|
@ -25,8 +25,12 @@ def _setup_full_tuning(
|
|||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
|
is_trainable: bool,
|
||||||
cast_trainable_params_to_fp32: bool,
|
cast_trainable_params_to_fp32: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not is_trainable:
|
||||||
|
return
|
||||||
|
|
||||||
logger.info("Fine-tuning method: Full")
|
logger.info("Fine-tuning method: Full")
|
||||||
forbidden_modules = set()
|
forbidden_modules = set()
|
||||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||||
@ -47,8 +51,12 @@ def _setup_freeze_tuning(
|
|||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
|
is_trainable: bool,
|
||||||
cast_trainable_params_to_fp32: bool,
|
cast_trainable_params_to_fp32: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not is_trainable:
|
||||||
|
return
|
||||||
|
|
||||||
logger.info("Fine-tuning method: Freeze")
|
logger.info("Fine-tuning method: Freeze")
|
||||||
if model_args.visual_inputs:
|
if model_args.visual_inputs:
|
||||||
config = model.config.text_config
|
config = model.config.text_config
|
||||||
@ -132,7 +140,9 @@ def _setup_lora_tuning(
|
|||||||
is_trainable: bool,
|
is_trainable: bool,
|
||||||
cast_trainable_params_to_fp32: bool,
|
cast_trainable_params_to_fp32: bool,
|
||||||
) -> "PeftModel":
|
) -> "PeftModel":
|
||||||
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
if is_trainable:
|
||||||
|
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||||
|
|
||||||
adapter_to_resume = None
|
adapter_to_resume = None
|
||||||
|
|
||||||
if model_args.adapter_name_or_path is not None:
|
if model_args.adapter_name_or_path is not None:
|
||||||
@ -173,6 +183,8 @@ def _setup_lora_tuning(
|
|||||||
offload_folder=model_args.offload_folder,
|
offload_folder=model_args.offload_folder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||||
|
|
||||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||||
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
|
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
|
||||||
@ -227,9 +239,6 @@ def _setup_lora_tuning(
|
|||||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
if model_args.adapter_name_or_path is not None:
|
|
||||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -247,29 +256,27 @@ def init_adapter(
|
|||||||
|
|
||||||
Note that the trainable parameters must be cast to float32.
|
Note that the trainable parameters must be cast to float32.
|
||||||
"""
|
"""
|
||||||
if (not is_trainable) and model_args.adapter_name_or_path is None:
|
if is_trainable and getattr(model, "quantization_method", None) and finetuning_args.finetuning_type != "lora":
|
||||||
logger.info("Adapter is not found at evaluation, load the base model.")
|
raise ValueError("Quantized models can only be used for the LoRA tuning.")
|
||||||
return model
|
|
||||||
|
|
||||||
if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
|
if not is_trainable:
|
||||||
raise ValueError("You can only use lora for quantized models.")
|
cast_trainable_params_to_fp32 = False
|
||||||
|
elif is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
||||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
|
||||||
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
||||||
cast_trainable_params_to_fp32 = False
|
cast_trainable_params_to_fp32 = False
|
||||||
else:
|
else:
|
||||||
logger.info("Upcasting trainable params to float32.")
|
logger.info("Upcasting trainable params to float32.")
|
||||||
cast_trainable_params_to_fp32 = True
|
cast_trainable_params_to_fp32 = True
|
||||||
|
|
||||||
if is_trainable and finetuning_args.finetuning_type == "full":
|
if finetuning_args.finetuning_type == "full":
|
||||||
_setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
|
_setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
|
||||||
|
elif finetuning_args.finetuning_type == "freeze":
|
||||||
if is_trainable and finetuning_args.finetuning_type == "freeze":
|
_setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
|
||||||
_setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
|
elif finetuning_args.finetuning_type == "lora":
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "lora":
|
|
||||||
model = _setup_lora_tuning(
|
model = _setup_lora_tuning(
|
||||||
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -44,7 +44,10 @@ def patch_config(
|
|||||||
is_trainable: bool,
|
is_trainable: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
if model_args.infer_dtype == "auto":
|
||||||
|
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||||
|
else:
|
||||||
|
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
|
||||||
|
|
||||||
if is_torch_npu_available():
|
if is_torch_npu_available():
|
||||||
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
|
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
|
||||||
|
@ -135,8 +135,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||||
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
|
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
|
||||||
|
|
||||||
device_type = unwrapped_model.pretrained_model.device.type
|
self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
|
||||||
self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
|
|
||||||
warnings.simplefilter("ignore") # remove gc warnings on ref model
|
warnings.simplefilter("ignore") # remove gc warnings on ref model
|
||||||
|
|
||||||
if finetuning_args.reward_model_type == "full":
|
if finetuning_args.reward_model_type == "full":
|
||||||
|
32
tests/model/test_base.py
Normal file
32
tests/model/test_base.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from llamafactory.hparams import get_infer_args
|
||||||
|
from llamafactory.model import load_model, load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||||
|
|
||||||
|
INFER_ARGS = {
|
||||||
|
"model_name_or_path": TINY_LLAMA,
|
||||||
|
"template": "llama3",
|
||||||
|
"infer_dtype": "float16",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
||||||
|
state_dict_a = model_a.state_dict()
|
||||||
|
state_dict_b = model_b.state_dict()
|
||||||
|
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||||
|
for name in state_dict_a.keys():
|
||||||
|
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_base():
|
||||||
|
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||||
|
ref_model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA, torch_dtype=model.dtype, device_map=model.device)
|
||||||
|
compare_model(model, ref_model)
|
@ -2,7 +2,7 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_infer_args, get_train_args
|
||||||
from llamafactory.model import load_model, load_tokenizer
|
from llamafactory.model import load_model, load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
@ -23,8 +23,15 @@ TRAIN_ARGS = {
|
|||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INFER_ARGS = {
|
||||||
|
"model_name_or_path": TINY_LLAMA,
|
||||||
|
"finetuning_type": "freeze",
|
||||||
|
"template": "llama3",
|
||||||
|
"infer_dtype": "float16",
|
||||||
|
}
|
||||||
|
|
||||||
def test_freeze_all_modules():
|
|
||||||
|
def test_freeze_train_all_modules():
|
||||||
model_args, _, _, finetuning_args, _ = get_train_args({"freeze_trainable_layers": 1, **TRAIN_ARGS})
|
model_args, _, _, finetuning_args, _ = get_train_args({"freeze_trainable_layers": 1, **TRAIN_ARGS})
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||||
@ -37,7 +44,7 @@ def test_freeze_all_modules():
|
|||||||
assert param.dtype == torch.float16
|
assert param.dtype == torch.float16
|
||||||
|
|
||||||
|
|
||||||
def test_freeze_extra_modules():
|
def test_freeze_train_extra_modules():
|
||||||
model_args, _, _, finetuning_args, _ = get_train_args(
|
model_args, _, _, finetuning_args, _ = get_train_args(
|
||||||
{"freeze_trainable_layers": 1, "freeze_extra_modules": "embed_tokens,lm_head", **TRAIN_ARGS}
|
{"freeze_trainable_layers": 1, "freeze_extra_modules": "embed_tokens,lm_head", **TRAIN_ARGS}
|
||||||
)
|
)
|
||||||
@ -50,3 +57,12 @@ def test_freeze_extra_modules():
|
|||||||
else:
|
else:
|
||||||
assert param.requires_grad is False
|
assert param.requires_grad is False
|
||||||
assert param.dtype == torch.float16
|
assert param.dtype == torch.float16
|
||||||
|
|
||||||
|
|
||||||
|
def test_freeze_inference():
|
||||||
|
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||||
|
for param in model.parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
assert param.dtype == torch.float16
|
||||||
|
@ -2,7 +2,7 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_infer_args, get_train_args
|
||||||
from llamafactory.model import load_model, load_tokenizer
|
from llamafactory.model import load_model, load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
@ -23,11 +23,27 @@ TRAIN_ARGS = {
|
|||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INFER_ARGS = {
|
||||||
|
"model_name_or_path": TINY_LLAMA,
|
||||||
|
"finetuning_type": "full",
|
||||||
|
"template": "llama3",
|
||||||
|
"infer_dtype": "float16",
|
||||||
|
}
|
||||||
|
|
||||||
def test_full():
|
|
||||||
|
def test_full_train():
|
||||||
model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
|
model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
assert param.requires_grad is True
|
assert param.requires_grad is True
|
||||||
assert param.dtype == torch.float32
|
assert param.dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_inference():
|
||||||
|
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||||
|
for param in model.parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
assert param.dtype == torch.float16
|
||||||
|
@ -1,13 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from peft import LoraModel, PeftModel
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_infer_args, get_train_args
|
||||||
from llamafactory.model import load_model, load_tokenizer
|
from llamafactory.model import load_model, load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||||
|
|
||||||
|
TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
|
||||||
|
|
||||||
TRAIN_ARGS = {
|
TRAIN_ARGS = {
|
||||||
"model_name_or_path": TINY_LLAMA,
|
"model_name_or_path": TINY_LLAMA,
|
||||||
"stage": "sft",
|
"stage": "sft",
|
||||||
@ -23,8 +28,32 @@ TRAIN_ARGS = {
|
|||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INFER_ARGS = {
|
||||||
|
"model_name_or_path": TINY_LLAMA,
|
||||||
|
"adapter_name_or_path": TINY_LLAMA_ADAPTER,
|
||||||
|
"finetuning_type": "lora",
|
||||||
|
"template": "llama3",
|
||||||
|
"infer_dtype": "float16",
|
||||||
|
}
|
||||||
|
|
||||||
def test_lora_all_modules():
|
|
||||||
|
def load_reference_model() -> "torch.nn.Module":
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA)
|
||||||
|
return PeftModel.from_pretrained(model, TINY_LLAMA_ADAPTER)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []):
|
||||||
|
state_dict_a = model_a.state_dict()
|
||||||
|
state_dict_b = model_b.state_dict()
|
||||||
|
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||||
|
for name in state_dict_a.keys():
|
||||||
|
if any(key in name for key in diff_keys):
|
||||||
|
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is False
|
||||||
|
else:
|
||||||
|
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_train_all_modules():
|
||||||
model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS})
|
model_args, _, _, finetuning_args, _ = get_train_args({"lora_target": "all", **TRAIN_ARGS})
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||||
@ -41,7 +70,7 @@ def test_lora_all_modules():
|
|||||||
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
|
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
|
||||||
|
|
||||||
|
|
||||||
def test_lora_extra_modules():
|
def test_lora_train_extra_modules():
|
||||||
model_args, _, _, finetuning_args, _ = get_train_args(
|
model_args, _, _, finetuning_args, _ = get_train_args(
|
||||||
{"lora_target": "all", "additional_target": "embed_tokens,lm_head", **TRAIN_ARGS}
|
{"lora_target": "all", "additional_target": "embed_tokens,lm_head", **TRAIN_ARGS}
|
||||||
)
|
)
|
||||||
@ -61,3 +90,51 @@ def test_lora_extra_modules():
|
|||||||
assert param.dtype == torch.float16
|
assert param.dtype == torch.float16
|
||||||
|
|
||||||
assert extra_modules == {"embed_tokens", "lm_head"}
|
assert extra_modules == {"embed_tokens", "lm_head"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_train_old_adapters():
|
||||||
|
model_args, _, _, finetuning_args, _ = get_train_args(
|
||||||
|
{"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": False, **TRAIN_ARGS}
|
||||||
|
)
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||||
|
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA, torch_dtype=model.dtype, device_map=model.device)
|
||||||
|
ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
|
||||||
|
for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
|
compare_model(model, ref_model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_train_new_adapters():
|
||||||
|
model_args, _, _, finetuning_args, _ = get_train_args(
|
||||||
|
{"adapter_name_or_path": TINY_LLAMA_ADAPTER, "create_new_adapter": True, **TRAIN_ARGS}
|
||||||
|
)
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
|
||||||
|
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA, torch_dtype=model.dtype, device_map=model.device)
|
||||||
|
ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER, is_trainable=True)
|
||||||
|
for param in filter(lambda p: p.requires_grad, ref_model.parameters()):
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
|
compare_model(
|
||||||
|
model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_inference():
|
||||||
|
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False)
|
||||||
|
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(TINY_LLAMA, torch_dtype=model.dtype, device_map=model.device)
|
||||||
|
ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_ADAPTER)
|
||||||
|
ref_model = ref_model.merge_and_unload()
|
||||||
|
compare_model(model, ref_model)
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
assert param.dtype == torch.float16
|
||||||
|
assert "lora" not in name
|
||||||
|
Loading…
x
Reference in New Issue
Block a user