Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-16 20:00:36 +08:00
remove PeftTrainer
@@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
 from datetime import timedelta
 
 from transformers import TrainerCallback
-from transformers.trainer_utils import has_length
+from transformers.trainer_callback import TrainerControl, TrainerState
+from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+from transformers.training_args import TrainingArguments
 
 from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
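For context on the new import: PREFIX_CHECKPOINT_DIR is the constant transformers uses to name checkpoint directories, which lets the new callback below reconstruct the folder the Trainer just wrote. A minimal sketch, with a made-up output directory:

import os
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR  # equals "checkpoint"

output_dir = "saves/llama-7b-lora"  # hypothetical output directory, for illustration only
global_step = 500
# Mirrors how the Trainer names its checkpoint folders: "<output_dir>/checkpoint-<step>"
ckpt_dir = os.path.join(output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, global_step))
print(ckpt_dir)  # saves/llama-7b-lora/checkpoint-500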
@@ -17,6 +19,24 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+class SavePeftModelCallback(TrainerCallback):
+
+    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after a checkpoint save.
+        """
+        output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
+        return control
+
+    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        r"""
+        Event called at the end of training.
+        """
+        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
+        return control
+
+
 class LogCallback(TrainerCallback):
 
     def __init__(self, runner=None):
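The new SavePeftModelCallback only persists the wrapped pretrained_model (e.g. the PEFT adapter) at each checkpoint and again at the end of training. A minimal sketch of exercising on_save directly; the stand-in objects below are placeholders, not the repo's actual Trainer wiring, and the module path in the import comment is assumed from the imports above:

import os
from types import SimpleNamespace
# from llmtuner.extras.callbacks import SavePeftModelCallback  # assumed module path

class _DummyPretrained:
    # Stand-in for a PEFT model exposed as model.pretrained_model in a real run.
    def save_pretrained(self, path):
        os.makedirs(path, exist_ok=True)
        print("adapter saved to", path)

args = SimpleNamespace(output_dir="saves/demo")       # stands in for TrainingArguments
state = SimpleNamespace(global_step=500)              # stands in for TrainerState
model = SimpleNamespace(pretrained_model=_DummyPretrained())

SavePeftModelCallback().on_save(args, state, control=None, model=model)
# -> adapter saved to saves/demo/checkpoint-500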
@@ -2,10 +2,6 @@ IGNORE_INDEX = -100
 
 LOG_FILE_NAME = "trainer_log.jsonl"
 
-VALUE_HEAD_FILE_NAME = "value_head.bin"
-
-FINETUNING_ARGS_NAME = "finetuning_args.json"
-
 LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
 
 METHODS = ["full", "freeze", "lora"]
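LOG_FILE_NAME names a JSON-lines file. A minimal sketch of writing and reading that format, purely to illustrate the file layout; the field names are examples, not necessarily what the repo's LogCallback emits:

import json
import os

LOG_FILE_NAME = "trainer_log.jsonl"
output_dir = "saves/demo"  # hypothetical output directory
os.makedirs(output_dir, exist_ok=True)

# One JSON object per line, appended at each logging step.
entry = {"current_steps": 10, "total_steps": 1000, "loss": 1.23}  # example fields
with open(os.path.join(output_dir, LOG_FILE_NAME), "a", encoding="utf-8") as f:
    f.write(json.dumps(entry) + "\n")

# Reading it back is line-by-line json.loads.
with open(os.path.join(output_dir, LOG_FILE_NAME), encoding="utf-8") as f:
    logs = [json.loads(line) for line in f]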
@@ -192,6 +192,7 @@ class FlashRotaryEmbedding(torch.nn.Module):
         else:
             assert False
 
+
 class LlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -204,26 +205,7 @@ class LlamaMLP(nn.Module):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
-        if self.config.pretraining_tp > 1:
-            slice = self.intermediate_size // self.config.pretraining_tp
-            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
-            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
-            down_proj_slices = self.down_proj.weight.split(slice, dim=1)
-
-            gate_proj = torch.cat(
-                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
-            )
-            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
-
-            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
-            down_proj = [
-                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
-            ]
-            down_proj = sum(down_proj)
-        else:
-            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-        return down_proj
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
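The simplified forward is the standard SwiGLU-style MLP: down_proj(act(gate_proj(x)) * up_proj(x)). A self-contained sketch with made-up sizes, just to show the shapes involved:

import torch
import torch.nn as nn

class SwiGLUMLP(nn.Module):
    """Same computation as the simplified LlamaMLP.forward above (bias-free projections)."""
    def __init__(self, hidden_size=32, intermediate_size=86):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()  # LLaMA configs use SiLU as hidden_act

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

x = torch.randn(2, 5, 32)    # (batch, seq_len, hidden_size)
print(SwiGLUMLP()(x).shape)  # torch.Size([2, 5, 32])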
@@ -301,27 +283,9 @@ class LlamaAttention(nn.Module):
         else:
             past_len = 0
 
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            q = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            q = torch.cat(q, dim=-1)
-
-            k = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            k = torch.cat(k, dim=-1)
-
-            v = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            v = torch.cat(v, dim=-1)
-
-        else:
-            q = self.q_proj(hidden_states)
-            k = self.k_proj(hidden_states)
-            v = self.v_proj(hidden_states)
+        q = self.q_proj(hidden_states)
+        k = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
 
         q = q.view(bsz, q_len, self.num_heads, self.head_dim)
         k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
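After projection, q/k/v are reshaped to (batch, seq_len, heads, head_dim), and repeat_kv then expands the key/value heads so grouped-query attention can be computed against all num_heads query heads. A standalone sketch of that expansion in this layout, with made-up shapes; it illustrates the usual GQA bookkeeping rather than the patch's exact implementation:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, seq_len, num_kv_heads, head_dim) -> (batch, seq_len, num_kv_heads * n_rep, head_dim)
    bsz, slen, num_kv_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, :, None, :].expand(bsz, slen, num_kv_heads, n_rep, head_dim)
    return hidden_states.reshape(bsz, slen, num_kv_heads * n_rep, head_dim)

bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 7, 8, 2, 16
k = torch.randn(bsz, q_len, num_kv_heads, head_dim)
print(repeat_kv(k, num_heads // num_kv_heads).shape)  # torch.Size([2, 7, 8, 16])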
@@ -377,12 +341,7 @@ class LlamaAttention(nn.Module):
         attn_output = attn_output.reshape(bsz, q_len, h_size)
         attn_weights = attn_outputs[2] if output_attentions else None
 
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
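The deleted branch was the pretraining_tp path: slice the projection weight, apply each slice to the matching input slice, and sum the partial results, which is numerically the same as the single matmul that remains. A small self-contained check of that equivalence for an output projection split along the input dimension (sizes made up):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp, hidden_size = 4, 32
o_proj_weight = torch.randn(hidden_size, hidden_size)  # stand-in for self.o_proj.weight
attn_output = torch.randn(2, 5, hidden_size)            # (batch, seq, hidden)

# Removed path: slice input and the weight's input dimension, then sum the partial projections.
inp_slices = attn_output.split(hidden_size // tp, dim=2)
w_slices = o_proj_weight.split(hidden_size // tp, dim=1)
partial = sum(F.linear(inp_slices[i], w_slices[i]) for i in range(tp))

# Kept path: one full linear.
full = F.linear(attn_output, o_proj_weight)
print(torch.allclose(partial, full, atol=1e-5))  # True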
@@ -703,12 +662,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         )
 
         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()
 
         loss = None
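After the simplified lm_head projection, the usual next step (not changed by this diff) is the shifted cross-entropy loss when labels are provided. A minimal sketch of that standard causal-LM loss with toy shapes:

import torch
import torch.nn as nn

vocab_size, bsz, seq_len = 100, 2, 6
logits = torch.randn(bsz, seq_len, vocab_size)        # stand-in for self.lm_head(hidden_states)
labels = torch.randint(0, vocab_size, (bsz, seq_len))

# Predict token t+1 from position t: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())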
@@ -1,49 +1,21 @@
 import os
 import torch
-from typing import Dict
-from transformers.trainer import WEIGHTS_NAME
+from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.modeling_utils import load_sharded_checkpoint
 
-from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
 from llmtuner.extras.logging import get_logger
 
 
 logger = get_logger(__name__)
 
 
-def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
-    state_dict: Dict[str, torch.Tensor] = model.state_dict()
-    filtered_state_dict = {}
-
-    for k, v in model.named_parameters():
-        if v.requires_grad:
-            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
-
-    return filtered_state_dict
-
-
 def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
     weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
     if os.path.exists(weights_file):
         model_state_dict = torch.load(weights_file, map_location="cpu")
         model.load_state_dict(model_state_dict, strict=False) # skip missing keys
+    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
+        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
     else:
         logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
         return False
     return True
 
 
 def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
-    if not os.path.exists(valuehead_file):
+    vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
+    if not os.path.exists(vhead_file):
         logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
         return False
-    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
-    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
-    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
-    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
+    vhead_params = torch.load(vhead_file, map_location="cpu")
+    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
+    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
     return True
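This hunk moves the value head out of a separate value_head.bin (keys summary.weight / summary.bias) and reads it from the regular weights file under v_head.summary.* instead. A hedged sketch of migrating an old checkpoint into the new key layout; the helper and paths are hypothetical illustrations, not part of the repo:

import os
import torch

def migrate_value_head(checkpoint_dir: str) -> None:
    """Hypothetical helper: move old value-head tensors into the v_head.summary.* layout."""
    old_file = os.path.join(checkpoint_dir, "value_head.bin")     # old VALUE_HEAD_FILE_NAME
    new_file = os.path.join(checkpoint_dir, "pytorch_model.bin")  # transformers WEIGHTS_NAME
    old_params = torch.load(old_file, map_location="cpu")

    state_dict = torch.load(new_file, map_location="cpu") if os.path.exists(new_file) else {}
    state_dict["v_head.summary.weight"] = old_params["summary.weight"]
    state_dict["v_head.summary.bias"] = old_params["summary.bias"]
    torch.save(state_dict, new_file)

# migrate_value_head("saves/reward-model/checkpoint-500")  # example call, path is made up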