remove PeftTrainer

Former-commit-id: b218c271edfb07006ddc34b1aca404088de6c528
hiyouga 2023-09-10 22:23:23 +08:00
parent cf08bcf3d9
commit 6a71361a54
17 changed files with 75 additions and 259 deletions

View File

@@ -1,5 +1,5 @@
 torch>=1.13.1
-transformers>=4.29.1
+transformers>=4.30.0
 datasets>=2.12.0
 accelerate>=0.21.0
 peft==0.4.0

View File

@@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
 from datetime import timedelta
 from transformers import TrainerCallback
-from transformers.trainer_utils import has_length
+from transformers.trainer_callback import TrainerControl, TrainerState
+from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+from transformers.training_args import TrainingArguments
 from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
@@ -17,6 +19,24 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+class SavePeftModelCallback(TrainerCallback):
+
+    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after a checkpoint save.
+        """
+        output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
+        return control
+
+    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        r"""
+        Event called at the end of training.
+        """
+        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
+        return control
+
+
 class LogCallback(TrainerCallback):
 
     def __init__(self, runner=None):

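Note: the new SavePeftModelCallback takes over the adapter-saving duty that PeftTrainer used to carry. A minimal usage sketch, assuming `model` is a trl AutoModelForCausalLMWithValueHead wrapping a PeftModel and `train_dataset` is prepared elsewhere (both are placeholders, not part of this commit):

    from transformers import Trainer, TrainingArguments
    from llmtuner.extras.callbacks import SavePeftModelCallback

    trainer = Trainer(
        model=model,                               # assumed: value-head wrapper exposing .pretrained_model
        args=TrainingArguments(output_dir="rm_output", save_steps=500),
        train_dataset=train_dataset,               # assumed: tokenized dataset prepared elsewhere
        callbacks=[SavePeftModelCallback()],       # saves model.pretrained_model at each checkpoint and at train end
    )
    trainer.train()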
View File

@@ -2,10 +2,6 @@ IGNORE_INDEX = -100
 LOG_FILE_NAME = "trainer_log.jsonl"
 
-VALUE_HEAD_FILE_NAME = "value_head.bin"
-
-FINETUNING_ARGS_NAME = "finetuning_args.json"
-
 LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
 
 METHODS = ["full", "freeze", "lora"]

View File

@@ -192,6 +192,7 @@ class FlashRotaryEmbedding(torch.nn.Module):
         else:
             assert False
 
+
 class LlamaMLP(nn.Module):
 
     def __init__(self, config):
         super().__init__()
@@ -204,26 +205,7 @@ class LlamaMLP(nn.Module):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
-        if self.config.pretraining_tp > 1:
-            slice = self.intermediate_size // self.config.pretraining_tp
-            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
-            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
-            down_proj_slices = self.down_proj.weight.split(slice, dim=1)
-            gate_proj = torch.cat(
-                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
-            )
-            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
-
-            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
-            down_proj = [
-                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
-            ]
-            down_proj = sum(down_proj)
-        else:
-            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-        return down_proj
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -301,27 +283,9 @@ class LlamaAttention(nn.Module):
         else:
             past_len = 0
 
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            q = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            q = torch.cat(q, dim=-1)
-
-            k = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            k = torch.cat(k, dim=-1)
-
-            v = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            v = torch.cat(v, dim=-1)
-        else:
-            q = self.q_proj(hidden_states)
-            k = self.k_proj(hidden_states)
-            v = self.v_proj(hidden_states)
+        q = self.q_proj(hidden_states)
+        k = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
 
         q = q.view(bsz, q_len, self.num_heads, self.head_dim)
         k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
@@ -377,12 +341,7 @@ class LlamaAttention(nn.Module):
         attn_output = attn_output.reshape(bsz, q_len, h_size)
         attn_weights = attn_outputs[2] if output_attentions else None
 
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
@@ -703,12 +662,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         )
 
         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()
 
         loss = None

View File

@@ -1,49 +1,21 @@
 import os
 import torch
-from typing import Dict
-from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from transformers.modeling_utils import load_sharded_checkpoint
-
-from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
+from transformers.trainer import WEIGHTS_NAME
+
 from llmtuner.extras.logging import get_logger
 
 
 logger = get_logger(__name__)
 
 
-def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
-    state_dict: Dict[str, torch.Tensor] = model.state_dict()
-    filtered_state_dict = {}
-
-    for k, v in model.named_parameters():
-        if v.requires_grad:
-            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
-
-    return filtered_state_dict
-
-
-def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    if os.path.exists(weights_file):
-        model_state_dict = torch.load(weights_file, map_location="cpu")
-        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
-    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
-        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
-    else:
-        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
-        return False
-    return True
-
-
 def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
-    if not os.path.exists(valuehead_file):
+    vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
+    if not os.path.exists(vhead_file):
         logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
         return False
-    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
-    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
-    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
-    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
+    vhead_params = torch.load(vhead_file, map_location="cpu")
+    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
+    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
     return True

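With VALUE_HEAD_FILE_NAME removed, the loader now expects the value-head tensors inside the checkpoint's regular WEIGHTS_NAME file under "v_head."-prefixed keys. A hedged sketch of how it might be called (paths are placeholders):

    from trl import AutoModelForCausalLMWithValueHead
    from llmtuner.extras.save_and_load import load_valuehead_params

    model = AutoModelForCausalLMWithValueHead.from_pretrained("path/to/sft_model")   # placeholder path
    if load_valuehead_params(model, "path/to/reward_model_checkpoint"):              # placeholder path
        # reward_head_weight / reward_head_bias are registered as buffers and can be
        # swapped in for reward scoring (see replace_model in llmtuner.tuner.ppo.utils).
        print(model.reward_head_weight.shape)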
View File

@@ -11,7 +11,6 @@ from peft import (
 from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.save_and_load import load_trainable_params
 from llmtuner.tuner.core.utils import find_all_linear_modules
 
 if TYPE_CHECKING:
@@ -53,9 +52,6 @@ def init_adapter(
             else:
                 param.data = param.data.to(torch.float32)
 
-    if model_args.checkpoint_dir is not None:
-        assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
-
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
         latest_checkpoint = None

View File

@@ -38,7 +38,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
-check_min_version("4.29.1")
+check_min_version("4.30.0")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
@@ -78,7 +78,7 @@ def load_model_and_tokenizer(
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
-    if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
+    if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
         model_to_load = model_args.checkpoint_dir[0]
     else:
         model_to_load = model_args.model_name_or_path
@@ -197,6 +197,7 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage == "rm" or stage == "ppo":
         model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        model._keys_to_ignore_on_save = None
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")

View File

@@ -1,118 +0,0 @@
-import os
-import torch
-from typing import TYPE_CHECKING, Dict, Optional
-
-from transformers import Seq2SeqTrainer
-from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
-from transformers.modeling_utils import PreTrainedModel, unwrap_model
-from peft import PeftModel
-from trl import PreTrainedModelWrapper
-
-from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
-
-if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
-    from llmtuner.hparams import FinetuningArguments
-
-
-logger = get_logger(__name__)
-
-
-class PeftModelMixin:
-    r"""
-    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
-    """
-
-    def __init__(self) -> None: # for type checking
-        self.model: PreTrainedModel = None
-        self.tokenizer: "PreTrainedTokenizer" = None
-        self.args: "Seq2SeqTrainingArguments" = None
-        self.finetuning_args: "FinetuningArguments" = None
-        self.state: "TrainerState" = None
-        raise AssertionError("Mixin should not be initialized.")
-
-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
-        r"""
-        Saves trainable parameters as model checkpoint.
-
-        This function will only be executed at the process zero.
-
-        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
-        """
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        os.makedirs(output_dir, exist_ok=True)
-        logger.info(f"Saving model checkpoint to {output_dir}")
-
-        model = self.model
-        model_unwrapped = unwrap_model(model)
-
-        if isinstance(model_unwrapped, PreTrainedModelWrapper):
-            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.7.1/trl/models/modeling_value_head.py#L200
-            model_state_dict = state_dict or model.state_dict()
-            v_head_state_dict = {
-                name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
-                for name in model_state_dict.keys() if name.startswith("v_head.")
-            }
-
-            torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-            model = model_unwrapped.pretrained_model
-            model_unwrapped = unwrap_model(model)
-
-        state_dict = state_dict or get_state_dict(model)
-
-        if not isinstance(model, (PeftModel, PreTrainedModel)):
-            if isinstance(model_unwrapped, (PeftModel, PreTrainedModel)):
-                model_unwrapped.config.use_cache = True
-                model_unwrapped.save_pretrained(
-                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
-                )
-                model_unwrapped.config.use_cache = False
-            else:
-                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
-        else:
-            model.config.use_cache = True
-            model.save_pretrained(
-                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
-            )
-            model.config.use_cache = False
-
-        if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
-            try:
-                self.tokenizer.save_pretrained(output_dir)
-            except:
-                logger.warning("Cannot save tokenizer, copy the files manually.")
-
-        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
-            f.write(self.args.to_json_string() + "\n")
-
-        self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
-
-    def _load_best_model(self):
-        r"""
-        Loads trainable parameters from model checkpoint.
-
-        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
-        """
-        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
-        model = unwrap_model(self.model)
-
-        if isinstance(model, PreTrainedModelWrapper):
-            model.v_head.load_state_dict(torch.load(
-                os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
-            ))
-            model = model.pretrained_model
-
-        if isinstance(model, PeftModel):
-            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
-        else: # freeze/full-tuning
-            load_trainable_params(model, self.state.best_model_checkpoint)
-
-
-class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
-    r"""
-    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
-    """
-
-    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
-        Seq2SeqTrainer.__init__(self, **kwargs)
-        self.finetuning_args = finetuning_args

View File

@@ -6,18 +6,16 @@ from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.tuner.core.trainer import PeftModelMixin
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from llmtuner.hparams import FinetuningArguments
 
 
-class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
+class CustomDPOTrainer(DPOTrainer):
 
     def __init__(
         self,
-        finetuning_args: "FinetuningArguments",
+        beta: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
@@ -28,12 +26,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         if ref_model is not None:
             disable_dropout_in_model(ref_model)
 
-        self.finetuning_args = finetuning_args
         self.ref_model = ref_model
         self.use_dpo_data_collator = True # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
-        self.beta = finetuning_args.dpo_beta
+        self.beta = beta
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
         Trainer.__init__(self, model=model, **kwargs)

View File

@@ -10,7 +10,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
+from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -37,10 +37,10 @@ def run_dpo(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = DPOPeftTrainer(
-        finetuning_args=finetuning_args,
-        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+    trainer = CustomDPOTrainer(
+        beta=finetuning_args.dpo_beta,
         model=model,
+        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,

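For clarity, the trainer now receives only the beta value rather than the whole FinetuningArguments. A sketch restating the new call with an illustrative beta (the other objects are assumed to be built by run_dpo as shown above):

    from copy import deepcopy
    from peft import PeftModel
    from llmtuner.tuner.dpo.trainer import CustomDPOTrainer

    trainer = CustomDPOTrainer(
        beta=0.1,                                                                  # was finetuning_args.dpo_beta
        model=model,
        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,   # LoRA: reuse the frozen base as reference
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )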
View File

@@ -4,27 +4,25 @@ import torch
 from tqdm import tqdm
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
-from transformers import GenerationConfig, TrainerState, TrainerControl
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 
 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.extras.callbacks import LogCallback
-    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import GeneratingArguments
 
 
 logger = get_logger(__name__)
 
 
-class PPOPeftTrainer(PPOTrainer, PeftTrainer):
+class CustomPPOTrainer(PPOTrainer, Trainer):
     r"""
     Inherits PPOTrainer.
     """
@@ -32,9 +30,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     def __init__(
         self,
         training_args: "Seq2SeqTrainingArguments",
-        finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: List["LogCallback"],
+        callbacks: List["TrainerCallback"],
         compute_dtype: torch.dtype,
         **kwargs
     ):
@@ -43,9 +40,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
 
         self.args = training_args
-        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
-        self.log_callback = callbacks[0]
+        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
@@ -147,7 +143,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 dataiter = iter(self.dataloader)
                 steps_trained = 0
 
-        self.log_callback.on_train_end(self.args, self.state, self.control)
+        self.log_callback.on_train_end(
+            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+        )
 
     @torch.no_grad()
     def get_inputs(
@@ -296,3 +294,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         """
         if self.args.should_save:
             self._save(output_dir)
+            self.save_callback.on_save(
+                self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+            )

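Since PPO training does not go through Trainer.train(), CustomPPOTrainer drives the callbacks itself and assumes a fixed ordering: index 0 is the logging callback, index 1 the saving callback. A short sketch of that contract (run_ppo below builds the list this way):

    from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback

    # callbacks[0] receives the logging events, callbacks[1] receives on_save with the unwrapped model.
    callbacks = [LogCallback(), SavePeftModelCallback()]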
View File

@@ -8,9 +8,10 @@ from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
 
 from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
+from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -61,11 +62,10 @@ def run_ppo(
     )
 
     # Initialize our Trainer
-    ppo_trainer = PPOPeftTrainer(
+    ppo_trainer = CustomPPOTrainer(
         training_args=training_args,
-        finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,

View File

@@ -2,12 +2,11 @@
 import math
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForLanguageModeling
+from transformers import DataCollatorForLanguageModeling, Trainer
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -27,8 +26,7 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
-    trainer = PeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = Trainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,

View File

@@ -2,9 +2,9 @@ import os
 import json
 import torch
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from transformers import Trainer
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class PairwisePeftTrainer(PeftTrainer):
+class PairwiseTrainer(Trainer):
     r"""
     Inherits PeftTrainer to compute pairwise loss.
     """

View File

@@ -5,11 +5,12 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
 from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
+from llmtuner.tuner.rm.trainer import PairwiseTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -33,13 +34,12 @@ def run_rm(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = PairwisePeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = PairwiseTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args)
     )

View File

@@ -4,10 +4,10 @@ import torch
 import numpy as np
 import torch.nn as nn
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from transformers import Seq2SeqTrainer
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class Seq2SeqPeftTrainer(PeftTrainer):
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     r"""
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
     """

View File

@@ -9,7 +9,7 @@ from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
-from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
+from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -45,8 +45,7 @@ def run_sft(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = Seq2SeqPeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,