mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	fix RM save model
Former-commit-id: 8104cc2425431eb1cddccf3909855296116f922b
This commit is contained in:
		
							parent
							
								
									9bba01a033
								
							
						
					
					
						commit
						8e26eb374e
					
				@ -128,6 +128,7 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
 | 
			
		||||
### Dependence Installation (optional)
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
git lfs install
 | 
			
		||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 | 
			
		||||
conda create -n llama_etuning python=3.10
 | 
			
		||||
conda activate llama_etuning
 | 
			
		||||
 | 
			
		||||
@ -128,6 +128,7 @@ huggingface-cli login
 | 
			
		||||
### 环境搭建(可跳过)
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
git lfs install
 | 
			
		||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 | 
			
		||||
conda create -n llama_etuning python=3.10
 | 
			
		||||
conda activate llama_etuning
 | 
			
		||||
 | 
			
		||||
@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
 | 
			
		||||
        self.log += "\n\n"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_logger(name: str) -> logging.Logger:
 | 
			
		||||
def reset_logging():
 | 
			
		||||
    r"""
 | 
			
		||||
    Removes basic config of root logger
 | 
			
		||||
    """
 | 
			
		||||
    root = logging.getLogger()
 | 
			
		||||
    list(map(root.removeHandler, root.handlers))
 | 
			
		||||
    list(map(root.removeFilter, root.filters))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_logger(name: str) -> logging.Logger:
 | 
			
		||||
    formatter = logging.Formatter(
 | 
			
		||||
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
 | 
			
		||||
        datefmt="%m/%d/%Y %H:%M:%S"
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,7 @@ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 | 
			
		||||
from transformers.tokenization_utils import PreTrainedTokenizerBase
 | 
			
		||||
from trl import AutoModelForCausalLMWithValueHead
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.extras.logging import reset_logging, get_logger
 | 
			
		||||
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 | 
			
		||||
from llmtuner.extras.save_and_load import load_valuehead_params
 | 
			
		||||
from llmtuner.hparams import FinetuningArguments
 | 
			
		||||
@ -95,7 +95,10 @@ def load_model_and_tokenizer(
 | 
			
		||||
        is_mergeable = False
 | 
			
		||||
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 | 
			
		||||
 | 
			
		||||
    if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
 | 
			
		||||
    if (
 | 
			
		||||
        model_args.quantization_bit is not None
 | 
			
		||||
        or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
 | 
			
		||||
    ):
 | 
			
		||||
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
 | 
			
		||||
@ -126,6 +129,7 @@ def load_model_and_tokenizer(
 | 
			
		||||
 | 
			
		||||
    if stage == "rm" or stage == "ppo": # add value head
 | 
			
		||||
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
 | 
			
		||||
        reset_logging()
 | 
			
		||||
 | 
			
		||||
        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
 | 
			
		||||
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
 | 
			
		||||
 | 
			
		||||
@ -85,6 +85,9 @@ def get_train_args(
 | 
			
		||||
    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
 | 
			
		||||
        "Streaming mode does not support evaluation currently."
 | 
			
		||||
 | 
			
		||||
    assert not (general_args.stage == "ppo" and data_args.streaming), \
 | 
			
		||||
        "Streaming mode does not suppport PPO training currently."
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None:
 | 
			
		||||
        if finetuning_args.finetuning_type != "lora":
 | 
			
		||||
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
 | 
			
		||||
@ -107,8 +110,8 @@ def get_train_args(
 | 
			
		||||
        training_args.ddp_find_unused_parameters = False
 | 
			
		||||
 | 
			
		||||
    if data_args.max_samples is not None and data_args.streaming:
 | 
			
		||||
        logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
 | 
			
		||||
        data_args.streaming = False
 | 
			
		||||
        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
 | 
			
		||||
        data_args.max_samples = None
 | 
			
		||||
 | 
			
		||||
    if data_args.dev_ratio > 1e-6 and data_args.streaming:
 | 
			
		||||
        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
 | 
			
		||||
 | 
			
		||||
@ -47,20 +47,19 @@ class PeftTrainer(Seq2SeqTrainer):
 | 
			
		||||
        logger.info(f"Saving model checkpoint to {output_dir}")
 | 
			
		||||
 | 
			
		||||
        model = unwrap_model(self.model)
 | 
			
		||||
        state_dict = state_dict or get_state_dict(model)
 | 
			
		||||
 | 
			
		||||
        if isinstance(model, PreTrainedModelWrapper):
 | 
			
		||||
            model_params, v_head_params = {}, {}
 | 
			
		||||
            for name in state_dict.keys():
 | 
			
		||||
                if name.startswith("pretrained_model."):
 | 
			
		||||
                    model_params[name.replace("pretrained_model.", "")] = state_dict[name]
 | 
			
		||||
                elif name.startswith("v_head."):
 | 
			
		||||
                    v_head_params[name.replace("v_head.", "")] = state_dict[name]
 | 
			
		||||
            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
 | 
			
		||||
            model_state_dict = state_dict or model.state_dict()
 | 
			
		||||
            v_head_state_dict = {
 | 
			
		||||
                name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
 | 
			
		||||
                for name in model_state_dict.keys() if name.startswith("v_head.")
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
 | 
			
		||||
            state_dict = model_params
 | 
			
		||||
            torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
 | 
			
		||||
            model = model.pretrained_model
 | 
			
		||||
 | 
			
		||||
        state_dict = state_dict or get_state_dict(model)
 | 
			
		||||
        if isinstance(model, (PeftModel, PreTrainedModel)):
 | 
			
		||||
            model.config.use_cache = True
 | 
			
		||||
            model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ from trl import PPOTrainer
 | 
			
		||||
from trl.core import LengthSampler
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.extras.misc import AverageMeter, get_logits_processor
 | 
			
		||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 | 
			
		||||
 | 
			
		||||
from llmtuner.tuner.core.trainer import PeftTrainer
 | 
			
		||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 | 
			
		||||
@ -29,6 +29,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 | 
			
		||||
    r"""
 | 
			
		||||
    Inherits PPOTrainer.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
@ -70,7 +71,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 | 
			
		||||
            logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
 | 
			
		||||
            logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
 | 
			
		||||
            logger.info(f"  Total optimization steps = {max_steps}")
 | 
			
		||||
            logger.info(f"  Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
 | 
			
		||||
            logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")
 | 
			
		||||
 | 
			
		||||
        # Keyword arguments for `model.generate`
 | 
			
		||||
        gen_kwargs = {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user