From 3d40bdb600f3c99ff9d07b34ede742d403de472d Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 7 Nov 2023 16:13:36 +0800 Subject: [PATCH] upgrade peft, fix #1088 #1411 Former-commit-id: b2a60905f384ada92618bf21301fe96dac1c10bf --- requirements.txt | 2 +- src/llmtuner/dsets/loader.py | 4 +- src/llmtuner/dsets/preprocess.py | 2 +- src/llmtuner/extras/constants.py | 2 +- src/llmtuner/hparams/finetuning_args.py | 8 +-- src/llmtuner/hparams/model_args.py | 10 +--- src/llmtuner/tuner/core/__init__.py | 1 + src/llmtuner/tuner/core/adapter.py | 35 +++++++++++-- src/llmtuner/tuner/core/loader.py | 30 ++++------- src/llmtuner/tuner/core/parser.py | 28 +++++------ src/llmtuner/tuner/core/utils.py | 67 +++++++++++++++---------- src/llmtuner/tuner/dpo/workflow.py | 9 +++- src/llmtuner/tuner/pt/workflow.py | 11 ++-- src/llmtuner/tuner/rm/workflow.py | 12 +++-- src/llmtuner/tuner/sft/workflow.py | 11 ++-- 15 files changed, 133 insertions(+), 99 deletions(-) diff --git a/requirements.txt b/requirements.txt index 840d2f2d..790dce6a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torch>=1.13.1 transformers>=4.31.0,<4.35.0 datasets>=2.12.0 accelerate>=0.21.0 -peft>=0.4.0 +peft>=0.6.0 trl>=0.7.2 gradio>=3.38.0,<4.0.0 scipy diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py index 834ef733..98d495e9 100644 --- a/src/llmtuner/dsets/loader.py +++ b/src/llmtuner/dsets/loader.py @@ -59,8 +59,8 @@ def get_dataset( data_files=data_files, split=data_args.split, cache_dir=model_args.cache_dir, - streaming=data_args.streaming, - use_auth_token=True if model_args.use_auth_token else None + token=model_args.hf_hub_token, + streaming=data_args.streaming ) if max_samples is not None: # truncate dataset diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 0484b78e..b331de35 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -257,7 +257,7 @@ def preprocess_dataset( if data_args.cache_path is not None and not os.path.exists(data_args.cache_path): if training_args.should_save: dataset.save_to_disk(data_args.cache_path) - raise SystemExit("Dataset saved, rerun this script with the same `--cache_file`.") + raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.") if training_args.should_log: try: diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index dc55a080..26b86579 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -2,7 +2,7 @@ IGNORE_INDEX = -100 LOG_FILE_NAME = "trainer_log.jsonl" -LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"] +LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2", "ln1", "ln2"] METHODS = ["full", "freeze", "lora"] diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index d5ef323d..cf7608e0 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -24,10 +24,10 @@ class FinetuningArguments: default="mlp", metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ LLaMA choices: [\"mlp\", \"self_attn\"], \ - BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \ + BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \ Qwen choices: [\"mlp\", \"attn\"], \ Phi-1.5 choices: [\"mlp\", \"mixer\"], \ - LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."} + LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."} ) lora_rank: Optional[int] = field( default=8, @@ -45,11 +45,11 @@ class FinetuningArguments: default=None, metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ - BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ + BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ - LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} + LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."} ) additional_target: Optional[str] = field( default=None, diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 7c25fad1..e14f55de 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -22,10 +22,6 @@ class ModelArguments: default=False, metadata={"help": "Whether or not the special tokens should be split during the tokenization process."} ) - use_auth_token: Optional[bool] = field( - default=False, - metadata={"help": "Will use the token generated when running `huggingface-cli login`."} - ) model_revision: Optional[str] = field( default="main", metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} @@ -66,7 +62,7 @@ class ModelArguments: default=False, metadata={"help": "Whether to plot the training loss after fine-tuning or not."} ) - hf_auth_token: Optional[str] = field( + hf_hub_token: Optional[str] = field( default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."} ) @@ -87,7 +83,3 @@ class ModelArguments: if self.quantization_bit is not None: assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." - - if self.use_auth_token == True and self.hf_auth_token is not None: - from huggingface_hub.hf_api import HfFolder # lazy load - HfFolder.save_token(self.hf_auth_token) diff --git a/src/llmtuner/tuner/core/__init__.py b/src/llmtuner/tuner/core/__init__.py index bd1c5cf0..ac621f7c 100644 --- a/src/llmtuner/tuner/core/__init__.py +++ b/src/llmtuner/tuner/core/__init__.py @@ -1,2 +1,3 @@ from llmtuner.tuner.core.parser import get_train_args, get_infer_args from llmtuner.tuner.core.loader import load_model_and_tokenizer +from llmtuner.tuner.core.utils import generate_model_card diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py index 4fcc6e62..25330545 100644 --- a/src/llmtuner/tuner/core/adapter.py +++ b/src/llmtuner/tuner/core/adapter.py @@ -1,6 +1,9 @@ +import os import torch from typing import TYPE_CHECKING +from transformers.utils import cached_file +from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME from peft import ( PeftModel, TaskType, @@ -23,8 +26,7 @@ def init_adapter( model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", - is_trainable: bool, - is_mergeable: bool + is_trainable: bool ) -> "PreTrainedModel": r""" Initializes the adapters. @@ -61,7 +63,7 @@ def init_adapter( latest_checkpoint = None if model_args.checkpoint_dir is not None: - if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning + if is_trainable and finetuning_args.resume_lora_training: # continually fine-tuning checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] else: checkpoints_to_merge = model_args.checkpoint_dir @@ -92,10 +94,33 @@ def init_adapter( modules_to_save=finetuning_args.additional_target ) model = get_peft_model(model, lora_config) - if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923 - model.base_model.peft_config = model.peft_config if model_args.checkpoint_dir is not None: logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) return model + + +def load_valuehead_params( + model: "PreTrainedModel", + model_args: "ModelArguments" +) -> None: + kwargs = { + "path_or_repo_id": model_args.reward_model, + "cache_dir": model_args.cache_dir, + "token": model_args.hf_hub_token, + "revision": model_args.model_revision + } + try: + vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs) + except: + try: + vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs) + except: + raise ValueError("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model)) + + vhead_params = torch.load(vhead_file, map_location="cpu") + model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) + model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) + model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) + model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index e77c4945..7cd49e79 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -25,9 +25,8 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v from llmtuner.extras.logging import reset_logging, get_logger from llmtuner.extras.misc import count_parameters, infer_optim_dtype from llmtuner.extras.patches import llama_patch as LlamaPatches -from llmtuner.extras.save_and_load import load_valuehead_params from llmtuner.hparams import FinetuningArguments -from llmtuner.tuner.core.adapter import init_adapter +from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params from llmtuner.tuner.core.utils import prepare_model_for_training if TYPE_CHECKING: @@ -41,7 +40,7 @@ logger = get_logger(__name__) require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") -require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") +require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0") require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2") @@ -64,7 +63,7 @@ def load_model_and_tokenizer( "trust_remote_code": True, "cache_dir": model_args.cache_dir, "revision": model_args.model_revision, - "use_auth_token": True if model_args.use_auth_token else None, + "token": model_args.hf_hub_token } tokenizer = AutoTokenizer.from_pretrained( @@ -99,15 +98,9 @@ def load_model_and_tokenizer( # Set RoPE scaling if model_args.rope_scaling is not None: - if hasattr(config, "use_dynamic_ntk"): # for Qwen models - if is_trainable: - logger.warning("Qwen model does not support RoPE scaling in training.") - else: - setattr(config, "use_dynamic_ntk", True) - setattr(config, "use_logn_attn", True) - logger.info("Using dynamic NTK scaling.") - - elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models + if not hasattr(config, "rope_scaling"): + logger.warning("Current model does not support RoPE scaling.") + else: if is_trainable: if model_args.rope_scaling == "dynamic": logger.warning( @@ -129,9 +122,6 @@ def load_model_and_tokenizer( model_args.rope_scaling, scaling_factor )) - else: - logger.warning("Current model does not support RoPE scaling.") - # Set FlashAttention-2 if model_args.flash_attn: if getattr(config, "model_type", None) == "llama": @@ -155,7 +145,6 @@ def load_model_and_tokenizer( logger.warning("Current model does not support shift short attention.") # Quantization configurations (using bitsandbytes library). - is_mergeable = True if model_args.quantization_bit is not None: if is_deepspeed_zero3_enabled(): raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") @@ -165,7 +154,7 @@ def load_model_and_tokenizer( config_kwargs["load_in_8bit"] = True config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) - elif model_args.quantization_bit == 4: + if model_args.quantization_bit == 4: require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") config_kwargs["load_in_4bit"] = True config_kwargs["quantization_config"] = BitsAndBytesConfig( @@ -175,7 +164,6 @@ def load_model_and_tokenizer( bnb_4bit_quant_type=model_args.quantization_type ) - is_mergeable = False config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto" logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) @@ -207,7 +195,7 @@ def load_model_and_tokenizer( # Initialize adapters model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model - model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) + model = init_adapter(model, model_args, finetuning_args, is_trainable) model = model.train() if is_trainable else model.eval() # Prepare model with valuehead for RLHF @@ -226,7 +214,7 @@ def load_model_and_tokenizer( logger.info("Load reward model from {}".format(model_args.reward_model)) if getattr(model, "is_peft_model", False): model.pretrained_model.load_adapter(model_args.reward_model, "reward") - assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded." + load_valuehead_params(model, model_args) # Prepare model for inference if not is_trainable: diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 603fc1bc..49c4e685 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -132,16 +132,12 @@ def get_train_args( if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": raise ValueError("Quantization is only compatible with the LoRA method.") - if model_args.checkpoint_dir is not None: - if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1: - raise ValueError("Only LoRA tuning accepts multiple checkpoints.") - - if model_args.quantization_bit is not None: - if len(model_args.checkpoint_dir) != 1: - raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.") - - if not finetuning_args.resume_lora_training: - raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.") + if ( + model_args.checkpoint_dir is not None + and len(model_args.checkpoint_dir) != 1 + and finetuning_args.finetuning_type != "lora" + ): + raise ValueError("Only LoRA tuning accepts multiple checkpoints.") if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm): logger.warning("We recommend enable `upcast_layernorm` in quantized training.") @@ -216,11 +212,11 @@ def get_infer_args( if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": raise ValueError("Quantization is only compatible with the LoRA method.") - if model_args.checkpoint_dir is not None: - if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1: - raise ValueError("Only LoRA tuning accepts multiple checkpoints.") - - if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1: - raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.") + if ( + model_args.checkpoint_dir is not None + and len(model_args.checkpoint_dir) != 1 + and finetuning_args.finetuning_type != "lora" + ): + raise ValueError("Only LoRA tuning accepts multiple checkpoints.") return model_args, data_args, finetuning_args, generating_args diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/tuner/core/utils.py index d9a1aac9..19fe42fd 100644 --- a/src/llmtuner/tuner/core/utils.py +++ b/src/llmtuner/tuner/core/utils.py @@ -1,13 +1,12 @@ import torch -from types import MethodType -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from llmtuner.extras.constants import LAYERNORM_NAMES from llmtuner.extras.logging import get_logger if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel - from llmtuner.hparams import FinetuningArguments + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments logger = get_logger(__name__) @@ -15,8 +14,7 @@ logger = get_logger(__name__) def find_all_linear_modules( model: "PreTrainedModel", - quantization_bit: Optional[int] = None, - output_layer_name: Optional[str] = "lm_head" + quantization_bit: Optional[int] = None ) -> List[str]: if quantization_bit is not None: import bitsandbytes as bnb @@ -24,17 +22,35 @@ def find_all_linear_modules( else: linear_cls = torch.nn.Linear + output_layer_names = ["lm_head"] + if model.config.model_type == "chatglm": + output_layer_names.append("output_layer") + module_names = set() for name, module in model.named_modules(): - if output_layer_name not in name and isinstance(module, linear_cls): + if ( + isinstance(module, linear_cls) + and not any([output_layer in name for output_layer in output_layer_names]) + ): module_names.add(name.split(".")[-1]) - if output_layer_name in module_names: - module_names.pop(output_layer_name) - + logger.info("Found linear modules: {}".format(",".join(module_names))) return list(module_names) +def generate_model_card( + model_args: "ModelArguments", + data_args: "DataArguments", + finetuning_args: "FinetuningArguments" +) -> Dict[str, Any]: + return { + "tasks": "text-generation", + "finetuned_from": model_args.model_name_or_path, + "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")], + "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []) + } + + def prepare_model_for_training( model: "PreTrainedModel", finetuning_args: "FinetuningArguments", @@ -56,26 +72,21 @@ def prepare_model_for_training( logger.info("Upcasting weights in layernorm in float32.") if finetuning_args.neft_alpha > 1e-6: - input_embed = model.get_input_embeddings() - if isinstance(input_embed, torch.nn.Embedding): - def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor: - embeddings = input_embed.__class__.forward(self, x) - if self.training: - dims = self.num_embeddings * self.embedding_dim - mag_norm = finetuning_args.neft_alpha / (dims ** 0.5) - embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm) - return embeddings + def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor): + if module.training: + dims = torch.tensor(output.size(1) * output.size(2)) + mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output - input_embed.forward = MethodType(noisy_forward, input_embed) - logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha)) - else: - logger.warning("Input embeddings are not normal nn.Embedding, cannot transform into noisy embedding.") + model.get_input_embeddings().register_forward_hook(neftune_forward_hook) + logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha)) if use_gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: - def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor): + def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) @@ -86,9 +97,11 @@ def prepare_model_for_training( if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name): output_layer = getattr(model, output_layer_name) if isinstance(output_layer, torch.nn.Linear): - def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor: - return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32) - - output_layer.forward = MethodType(forward_in_fp32, output_layer) + def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]): + return args[0].to(output_layer.weight.dtype) + def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor): + return output.to(torch.float32) + output_layer.register_forward_pre_hook(fp32_forward_pre_hook) + output_layer.register_forward_hook(fp32_forward_post_hook) return model diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py index 6e16dd18..63968604 100644 --- a/src/llmtuner/tuner/dpo/workflow.py +++ b/src/llmtuner/tuner/dpo/workflow.py @@ -8,7 +8,7 @@ from transformers import Seq2SeqTrainingArguments from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding from llmtuner.tuner.dpo.trainer import CustomDPOTrainer @@ -52,13 +52,18 @@ def run_dpo( # Training if training_args.do_train: train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() - trainer.save_model() if trainer.is_world_process_zero() and model_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + if training_args.push_to_hub: + trainer.push_to_hub(**generate_model_card()) + else: + trainer.create_model_card(**generate_model_card()) + # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval") diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py index 66d08de7..002d2dd1 100644 --- a/src/llmtuner/tuner/pt/workflow.py +++ b/src/llmtuner/tuner/pt/workflow.py @@ -1,4 +1,4 @@ -# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py +# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py import math from typing import TYPE_CHECKING, Optional, List @@ -6,7 +6,7 @@ from transformers import DataCollatorForLanguageModeling, Trainer from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -38,13 +38,18 @@ def run_pt( # Training if training_args.do_train: train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() - trainer.save_model() if trainer.is_world_process_zero() and model_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + if training_args.push_to_hub: + trainer.push_to_hub(**generate_model_card()) + else: + trainer.create_model_card(**generate_model_card()) + # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval") diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py index 6d2c4422..c95f1cb6 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/tuner/rm/workflow.py @@ -1,5 +1,4 @@ -# Inspired by: -# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py +# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py from typing import TYPE_CHECKING, Optional, List from transformers import Seq2SeqTrainingArguments @@ -7,7 +6,7 @@ from transformers import Seq2SeqTrainingArguments from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.callbacks import SavePeftModelCallback from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer from llmtuner.tuner.rm.metric import compute_accuracy from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding from llmtuner.tuner.rm.trainer import PairwiseTrainer @@ -47,13 +46,18 @@ def run_rm( # Training if training_args.do_train: train_result = trainer.train() + trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() - trainer.save_model() if trainer.is_world_process_zero() and model_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + if training_args.push_to_hub: + trainer.push_to_hub(**generate_model_card()) + else: + trainer.create_model_card(**generate_model_card()) + # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval") diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index 8d53605d..dc22904b 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -1,4 +1,4 @@ -# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py +# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments @@ -7,7 +7,7 @@ from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer from llmtuner.tuner.sft.metric import ComputeMetrics from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer @@ -65,13 +65,18 @@ def run_sft( # Training if training_args.do_train: train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() - trainer.save_model() if trainer.is_world_process_zero() and model_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + if training_args.push_to_hub: + trainer.push_to_hub(**generate_model_card()) + else: + trainer.create_model_card(**generate_model_card()) + # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)