diff --git a/requirements.txt b/requirements.txt
index ff52658d..bb132738 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 torch>=1.13.1
-transformers>=4.29.1
+transformers>=4.30.0
 datasets>=2.12.0
 accelerate>=0.21.0
 peft==0.4.0
diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index 61deae25..8d7a1161 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
 from datetime import timedelta
 
 from transformers import TrainerCallback
-from transformers.trainer_utils import has_length
+from transformers.trainer_callback import TrainerControl, TrainerState
+from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
+from transformers.training_args import TrainingArguments
 
 from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
@@ -17,6 +19,24 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+class SavePeftModelCallback(TrainerCallback):
+
+    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after a checkpoint save.
+        """
+        output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
+        return control
+
+    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        r"""
+        Event called at the end of training.
+        """
+        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
+        return control
+
+
 class LogCallback(TrainerCallback):
 
     def __init__(self, runner=None):
diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index f042f76d..0d1694b4 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -2,10 +2,6 @@ IGNORE_INDEX = -100
 
 LOG_FILE_NAME = "trainer_log.jsonl"
 
-VALUE_HEAD_FILE_NAME = "value_head.bin"
-
-FINETUNING_ARGS_NAME = "finetuning_args.json"
-
 LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
 
 METHODS = ["full", "freeze", "lora"]
diff --git a/src/llmtuner/extras/models/flash_llama.py b/src/llmtuner/extras/models/flash_llama.py
index 608dbc07..8a4dae2a 100644
--- a/src/llmtuner/extras/models/flash_llama.py
+++ b/src/llmtuner/extras/models/flash_llama.py
@@ -192,6 +192,7 @@ class FlashRotaryEmbedding(torch.nn.Module):
         else:
             assert False
 
+
 class LlamaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -204,26 +205,7 @@ class LlamaMLP(nn.Module):
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
-        if self.config.pretraining_tp > 1:
-            slice = self.intermediate_size // self.config.pretraining_tp
-            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
-            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
-            down_proj_slices = self.down_proj.weight.split(slice, dim=1)
-
-            gate_proj = torch.cat(
-                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
-            )
-            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
-
-            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
-            down_proj = [
-                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
-            ]
-            down_proj = sum(down_proj)
-        else:
-            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-        return down_proj
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -301,27 +283,9 @@ class LlamaAttention(nn.Module):
         else:
             past_len = 0
 
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            q = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            q = torch.cat(q, dim=-1)
-
-            k = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            k = torch.cat(k, dim=-1)
-
-            v = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            v = torch.cat(v, dim=-1)
-
-        else:
-            q = self.q_proj(hidden_states)
-            k = self.k_proj(hidden_states)
-            v = self.v_proj(hidden_states)
+        q = self.q_proj(hidden_states)
+        k = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
 
         q = q.view(bsz, q_len, self.num_heads, self.head_dim)
         k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
@@ -377,12 +341,7 @@ class LlamaAttention(nn.Module):
         attn_output = attn_output.reshape(bsz, q_len, h_size)
         attn_weights = attn_outputs[2] if output_attentions else None
 
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
@@ -703,12 +662,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         )
 
         hidden_states = outputs[0]
-        if self.config.pretraining_tp > 1:
-            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
-            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
-            logits = torch.cat(logits, dim=-1)
-        else:
-            logits = self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
         logits = logits.float()
 
         loss = None
diff --git a/src/llmtuner/extras/save_and_load.py b/src/llmtuner/extras/save_and_load.py
index af66248d..6d819ce6 100644
--- a/src/llmtuner/extras/save_and_load.py
+++ b/src/llmtuner/extras/save_and_load.py
@@ -1,49 +1,21 @@
 import os
 import torch
-from typing import Dict
+from transformers.trainer import WEIGHTS_NAME
 
-from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from transformers.modeling_utils import load_sharded_checkpoint
-
-from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
 from llmtuner.extras.logging import get_logger
 
 
 logger = get_logger(__name__)
 
 
-def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
-    state_dict: Dict[str, torch.Tensor] = model.state_dict()
-    filtered_state_dict = {}
-
-    for k, v in model.named_parameters():
-        if v.requires_grad:
-            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
-
-    return filtered_state_dict
-
-
-def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    if os.path.exists(weights_file):
-        model_state_dict = torch.load(weights_file, map_location="cpu")
-        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
-    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
-        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
-    else:
-        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
-        return False
-    return True
-
-
 def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
-    if not os.path.exists(valuehead_file):
+    vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
+    if not os.path.exists(vhead_file):
         logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
         return False
-    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
-    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
-    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
-    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
+    vhead_params = torch.load(vhead_file, map_location="cpu")
+    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
+    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
     return True
diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py
index 0324bc74..64d1f485 100644
--- a/src/llmtuner/tuner/core/adapter.py
+++ b/src/llmtuner/tuner/core/adapter.py
@@ -11,7 +11,6 @@ from peft import (
 from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.save_and_load import load_trainable_params
 from llmtuner.tuner.core.utils import find_all_linear_modules
 
 if TYPE_CHECKING:
@@ -53,9 +52,6 @@ def init_adapter(
         else:
             param.data = param.data.to(torch.float32)
 
-    if model_args.checkpoint_dir is not None:
-        assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
-
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
         latest_checkpoint = None
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 5c9fbece..820a5714 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
-check_min_version("4.29.1")
+check_min_version("4.30.0")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
@@ -78,7 +78,7 @@ def load_model_and_tokenizer(
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
-    if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
+    if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
         model_to_load = model_args.checkpoint_dir[0]
     else:
         model_to_load = model_args.model_name_or_path
@@ -197,6 +197,7 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage == "rm" or stage == "ppo":
         model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        model._keys_to_ignore_on_save = None
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py
deleted file mode 100644
index 9a46d59f..00000000
--- a/src/llmtuner/tuner/core/trainer.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import os
-import torch
-from typing import TYPE_CHECKING, Dict, Optional
-
-from transformers import Seq2SeqTrainer
-from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
-from transformers.modeling_utils import PreTrainedModel, unwrap_model
-from peft import PeftModel
-from trl import PreTrainedModelWrapper
-
-from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
-
-if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
-    from llmtuner.hparams import FinetuningArguments
-
-
-logger = get_logger(__name__)
-
-
-class PeftModelMixin:
-    r"""
-    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
-    """
-
-    def __init__(self) -> None: # for type checking
-        self.model: PreTrainedModel = None
-        self.tokenizer: "PreTrainedTokenizer" = None
-        self.args: "Seq2SeqTrainingArguments" = None
-        self.finetuning_args: "FinetuningArguments" = None
-        self.state: "TrainerState" = None
-        raise AssertionError("Mixin should not be initialized.")
-
-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
-        r"""
-        Saves trainable parameters as model checkpoint.
-
-        This function will only be executed at the process zero.
-
-        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
- """ - output_dir = output_dir if output_dir is not None else self.args.output_dir - os.makedirs(output_dir, exist_ok=True) - logger.info(f"Saving model checkpoint to {output_dir}") - model = self.model - model_unwrapped = unwrap_model(model) - - if isinstance(model_unwrapped, PreTrainedModelWrapper): - # Custom state dict: https://github.com/lvwerra/trl/blob/v0.7.1/trl/models/modeling_value_head.py#L200 - model_state_dict = state_dict or model.state_dict() - v_head_state_dict = { - name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach() - for name in model_state_dict.keys() if name.startswith("v_head.") - } - torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME)) - model = model_unwrapped.pretrained_model - model_unwrapped = unwrap_model(model) - - state_dict = state_dict or get_state_dict(model) - if not isinstance(model, (PeftModel, PreTrainedModel)): - if isinstance(model_unwrapped, (PeftModel, PreTrainedModel)): - model_unwrapped.config.use_cache = True - model_unwrapped.save_pretrained( - output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors - ) - model_unwrapped.config.use_cache = False - else: - logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) - else: - model.config.use_cache = True - model.save_pretrained( - output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors - ) - model.config.use_cache = False - - if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None: - try: - self.tokenizer.save_pretrained(output_dir) - except: - logger.warning("Cannot save tokenizer, copy the files manually.") - - with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f: - f.write(self.args.to_json_string() + "\n") - - self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME)) - - def _load_best_model(self): - r""" - Loads trainable parameters from model checkpoint. - - Subclass and override to inject custom behavior. It should not be directly used by external scripts. - """ - logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") - model = unwrap_model(self.model) - - if isinstance(model, PreTrainedModelWrapper): - model.v_head.load_state_dict(torch.load( - os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu" - )) - model = model.pretrained_model - - if isinstance(model, PeftModel): - model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) - else: # freeze/full-tuning - load_trainable_params(model, self.state.best_model_checkpoint) - - -class PeftTrainer(PeftModelMixin, Seq2SeqTrainer): - r""" - Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. 
- """ - - def __init__(self, finetuning_args: "FinetuningArguments", **kwargs): - Seq2SeqTrainer.__init__(self, **kwargs) - self.finetuning_args = finetuning_args diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py index 572ce13d..0036fe0f 100644 --- a/src/llmtuner/tuner/dpo/trainer.py +++ b/src/llmtuner/tuner/dpo/trainer.py @@ -6,18 +6,16 @@ from trl import DPOTrainer from trl.trainer.utils import disable_dropout_in_model from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.tuner.core.trainer import PeftModelMixin if TYPE_CHECKING: from transformers import PreTrainedModel - from llmtuner.hparams import FinetuningArguments -class DPOPeftTrainer(PeftModelMixin, DPOTrainer): +class CustomDPOTrainer(DPOTrainer): def __init__( self, - finetuning_args: "FinetuningArguments", + beta: float, model: Union["PreTrainedModel", torch.nn.Module], ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, disable_dropout: Optional[bool] = True, @@ -28,12 +26,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer): if ref_model is not None: disable_dropout_in_model(ref_model) - self.finetuning_args = finetuning_args self.ref_model = ref_model self.use_dpo_data_collator = True # hack to avoid warning self.label_pad_token_id = IGNORE_INDEX self.padding_value = 0 - self.beta = finetuning_args.dpo_beta + self.beta = beta self._stored_metrics = defaultdict(lambda: defaultdict(list)) Trainer.__init__(self, model=model, **kwargs) diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py index 31d82fbf..4abd3894 100644 --- a/src/llmtuner/tuner/dpo/workflow.py +++ b/src/llmtuner/tuner/dpo/workflow.py @@ -10,7 +10,7 @@ from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.ploting import plot_loss from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding -from llmtuner.tuner.dpo.trainer import DPOPeftTrainer +from llmtuner.tuner.dpo.trainer import CustomDPOTrainer if TYPE_CHECKING: from transformers import TrainerCallback @@ -37,10 +37,10 @@ def run_dpo( training_args = Seq2SeqTrainingArguments(**training_args_dict) # Initialize our Trainer - trainer = DPOPeftTrainer( - finetuning_args=finetuning_args, - ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None, + trainer = CustomDPOTrainer( + beta=finetuning_args.dpo_beta, model=model, + ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None, args=training_args, tokenizer=tokenizer, data_collator=data_collator, diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 981e6d41..49d2d1a9 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -4,27 +4,25 @@ import torch from tqdm import tqdm from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple -from transformers import GenerationConfig, TrainerState, TrainerControl +from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl from trl import PPOTrainer from trl.core import LengthSampler, PPODecorators, logprobs_from_logits from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor -from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model if TYPE_CHECKING: - from transformers import Seq2SeqTrainingArguments + from transformers import Seq2SeqTrainingArguments, TrainerCallback 
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.extras.callbacks import LogCallback
-    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import GeneratingArguments
 
 
 logger = get_logger(__name__)
 
 
-class PPOPeftTrainer(PPOTrainer, PeftTrainer):
+class CustomPPOTrainer(PPOTrainer, Trainer):
     r"""
     Inherits PPOTrainer.
     """
 
@@ -32,9 +30,8 @@
     def __init__(
         self,
         training_args: "Seq2SeqTrainingArguments",
-        finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: List["LogCallback"],
+        callbacks: List["TrainerCallback"],
         compute_dtype: torch.dtype,
         **kwargs
     ):
@@ -43,9 +40,8 @@
             raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
 
         self.args = training_args
-        self.finetuning_args = finetuning_args
         self.generating_args = generating_args
-        self.log_callback = callbacks[0]
+        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
@@ -147,7 +143,9 @@
                 dataiter = iter(self.dataloader)
                 steps_trained = 0
 
-        self.log_callback.on_train_end(self.args, self.state, self.control)
+        self.log_callback.on_train_end(
+            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+        )
 
     @torch.no_grad()
     def get_inputs(
@@ -296,3 +294,6 @@
         """
         if self.args.should_save:
             self._save(output_dir)
+            self.save_callback.on_save(
+                self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+            )
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 66daa99c..4cfec75b 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -8,9 +8,10 @@ from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
 
 from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
+from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -61,11 +62,10 @@ def run_ppo(
     )
 
     # Initialize our Trainer
-    ppo_trainer = PPOPeftTrainer(
+    ppo_trainer = CustomPPOTrainer(
         training_args=training_args,
-        finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py
index 9908dece..66d08de7 100644
--- a/src/llmtuner/tuner/pt/workflow.py
+++ b/src/llmtuner/tuner/pt/workflow.py
@@ -2,12 +2,11 @@
 
 import math
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForLanguageModeling
+from transformers import DataCollatorForLanguageModeling, Trainer
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -27,8 +26,7 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
-    trainer = PeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = Trainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py
index 23b33539..80502937 100644
--- a/src/llmtuner/tuner/rm/trainer.py
+++ b/src/llmtuner/tuner/rm/trainer.py
@@ -2,9 +2,9 @@ import os
 import json
 import torch
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from transformers import Trainer
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class PairwisePeftTrainer(PeftTrainer):
+class PairwiseTrainer(Trainer):
     r"""
     Inherits PeftTrainer to compute pairwise loss.
     """
diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py
index 91441f70..edc8e7c5 100644
--- a/src/llmtuner/tuner/rm/workflow.py
+++ b/src/llmtuner/tuner/rm/workflow.py
@@ -5,11 +5,12 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
 from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
+from llmtuner.tuner.rm.trainer import PairwiseTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -33,13 +34,12 @@ def run_rm(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = PairwisePeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = PairwiseTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args)
     )
diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/tuner/sft/trainer.py
index db8878f6..4fafc76b 100644
--- a/src/llmtuner/tuner/sft/trainer.py
+++ b/src/llmtuner/tuner/sft/trainer.py
@@ -4,10 +4,10 @@ import torch
 import numpy as np
 import torch.nn as nn
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from transformers import Seq2SeqTrainer
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class Seq2SeqPeftTrainer(PeftTrainer):
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     r"""
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
""" diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index 2ae86fbd..05942780 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -9,7 +9,7 @@ from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.ploting import plot_loss from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.sft.metric import ComputeMetrics -from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer +from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer if TYPE_CHECKING: from transformers import TrainerCallback @@ -45,8 +45,7 @@ def run_sft( training_args = Seq2SeqTrainingArguments(**training_args_dict) # Initialize our Trainer - trainer = Seq2SeqPeftTrainer( - finetuning_args=finetuning_args, + trainer = CustomSeq2SeqTrainer( model=model, args=training_args, tokenizer=tokenizer,