mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	fix llava rlhf
Former-commit-id: f6863cbbcbf960d6481296c6cae3e40fd70e4e14
This commit is contained in:
		
							parent
							
								
									a412b4ed4a
								
							
						
					
					
						commit
						4dcd47100d
					
				@ -1,5 +1,6 @@
 | 
			
		||||
from .loader import load_config, load_model, load_tokenizer
 | 
			
		||||
from .utils.misc import find_all_linear_modules, load_valuehead_params
 | 
			
		||||
from .utils.misc import find_all_linear_modules
 | 
			
		||||
from .utils.valuehead import load_valuehead_params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
 | 
			
		||||
@ -7,9 +7,10 @@ from ..extras.logging import get_logger
 | 
			
		||||
from ..extras.misc import count_parameters, try_download_model_from_ms
 | 
			
		||||
from .adapter import init_adapter
 | 
			
		||||
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
 | 
			
		||||
from .utils.misc import load_valuehead_params, register_autoclass
 | 
			
		||||
from .utils.misc import register_autoclass
 | 
			
		||||
from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 | 
			
		||||
from .utils.unsloth import load_unsloth_pretrained_model
 | 
			
		||||
from .utils.valuehead import load_valuehead_params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@ -105,7 +106,7 @@ def load_model(
 | 
			
		||||
    """
 | 
			
		||||
    init_kwargs = _get_init_kwargs(model_args)
 | 
			
		||||
    config = load_config(model_args)
 | 
			
		||||
    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 | 
			
		||||
    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
 | 
			
		||||
 | 
			
		||||
    model = None
 | 
			
		||||
    lazy_load = False
 | 
			
		||||
@ -130,7 +131,7 @@ def load_model(
 | 
			
		||||
            model = convert_pretrained_model_to_mod(model, config, model_args)
 | 
			
		||||
 | 
			
		||||
    if not lazy_load:
 | 
			
		||||
        patch_model(model, tokenizer, model_args, is_trainable)
 | 
			
		||||
        patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
 | 
			
		||||
        register_autoclass(config, model, tokenizer)
 | 
			
		||||
 | 
			
		||||
    model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@ from .utils.longlora import configure_longlora
 | 
			
		||||
from .utils.moe import add_z3_leaf_module, configure_moe
 | 
			
		||||
from .utils.quantization import configure_quantization
 | 
			
		||||
from .utils.rope import configure_rope
 | 
			
		||||
from .utils.valuehead import configure_valuehead, prepare_valuehead_model
 | 
			
		||||
from .utils.visual import autocast_projector_dtype
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -39,6 +40,7 @@ def patch_config(
 | 
			
		||||
    model_args: "ModelArguments",
 | 
			
		||||
    init_kwargs: Dict[str, Any],
 | 
			
		||||
    is_trainable: bool,
 | 
			
		||||
    add_valuehead: bool,
 | 
			
		||||
) -> None:
 | 
			
		||||
    if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
 | 
			
		||||
        model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 | 
			
		||||
@ -49,6 +51,9 @@ def patch_config(
 | 
			
		||||
    configure_quantization(config, tokenizer, model_args, init_kwargs)
 | 
			
		||||
    configure_moe(config, model_args, is_trainable)
 | 
			
		||||
 | 
			
		||||
    if add_valuehead:
 | 
			
		||||
        configure_valuehead(config)
 | 
			
		||||
 | 
			
		||||
    if model_args.use_cache and not is_trainable:
 | 
			
		||||
        setattr(config, "use_cache", True)
 | 
			
		||||
        logger.info("Using KV cache for faster generation.")
 | 
			
		||||
@ -73,7 +78,11 @@ def patch_config(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_model(
 | 
			
		||||
    model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
 | 
			
		||||
    model: "PreTrainedModel",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    model_args: "ModelArguments",
 | 
			
		||||
    is_trainable: bool,
 | 
			
		||||
    add_valuehead: bool,
 | 
			
		||||
) -> None:
 | 
			
		||||
    gen_config = model.generation_config  # check and fix generation config
 | 
			
		||||
    if not gen_config.do_sample and (
 | 
			
		||||
@ -86,9 +95,8 @@ def patch_model(
 | 
			
		||||
    if "GenerationMixin" not in str(model.generate.__func__):
 | 
			
		||||
        model.generate = MethodType(PreTrainedModel.generate, model)
 | 
			
		||||
 | 
			
		||||
    if is_trainable and getattr(model.config, "model_type", None) == "chatglm":
 | 
			
		||||
        setattr(model, "lm_head", model.transformer.output_layer)
 | 
			
		||||
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
 | 
			
		||||
    if add_valuehead:
 | 
			
		||||
        prepare_valuehead_model(model)
 | 
			
		||||
 | 
			
		||||
    if model_args.resize_vocab:
 | 
			
		||||
        resize_embedding_layer(model, tokenizer)
 | 
			
		||||
 | 
			
		||||
@ -1,18 +1,13 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, List
 | 
			
		||||
from typing import TYPE_CHECKING, List
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers import PreTrainedModel
 | 
			
		||||
from transformers.utils import cached_file
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 | 
			
		||||
from ...extras.logging import get_logger
 | 
			
		||||
from .quantization import QuantizationMethod
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PretrainedConfig, PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    from ...hparams import ModelArguments
 | 
			
		||||
    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
@ -74,34 +69,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
 | 
			
		||||
    return module_names
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Loads value head parameters from Hugging Face Hub or local disk.
 | 
			
		||||
 | 
			
		||||
    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
 | 
			
		||||
    """
 | 
			
		||||
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        from safetensors import safe_open
 | 
			
		||||
 | 
			
		||||
        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
 | 
			
		||||
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
 | 
			
		||||
            return {key: f.get_tensor(key) for key in f.keys()}
 | 
			
		||||
    except Exception as err:
 | 
			
		||||
        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
 | 
			
		||||
        return torch.load(vhead_file, map_location="cpu")
 | 
			
		||||
    except Exception as err:
 | 
			
		||||
        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
 | 
			
		||||
 | 
			
		||||
    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
 | 
			
		||||
    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
 | 
			
		||||
    if "AutoConfig" in getattr(config, "auto_map", {}):
 | 
			
		||||
        config.__class__.register_for_auto_class()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										59
									
								
								src/llmtuner/model/utils/valuehead.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								src/llmtuner/model/utils/valuehead.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,59 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Dict
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers.utils import cached_file
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 | 
			
		||||
from ...extras.logging import get_logger
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PretrainedConfig, PreTrainedModel
 | 
			
		||||
 | 
			
		||||
    from ...hparams import ModelArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def configure_valuehead(config: "PretrainedConfig") -> None:
 | 
			
		||||
    if getattr(config, "model_type", None) == "llava":
 | 
			
		||||
        setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Loads value head parameters from Hugging Face Hub or local disk.
 | 
			
		||||
 | 
			
		||||
    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
 | 
			
		||||
    """
 | 
			
		||||
    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        from safetensors import safe_open
 | 
			
		||||
 | 
			
		||||
        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
 | 
			
		||||
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
 | 
			
		||||
            return {key: f.get_tensor(key) for key in f.keys()}
 | 
			
		||||
    except Exception as err:
 | 
			
		||||
        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
 | 
			
		||||
        return torch.load(vhead_file, map_location="cpu")
 | 
			
		||||
    except Exception as err:
 | 
			
		||||
        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
 | 
			
		||||
 | 
			
		||||
    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
 | 
			
		||||
    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prepare_valuehead_model(model: "PreTrainedModel") -> None:
 | 
			
		||||
    if getattr(model.config, "model_type", None) == "llava":
 | 
			
		||||
        setattr(model, "lm_head", model.language_model.get_output_embeddings())
 | 
			
		||||
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
 | 
			
		||||
 | 
			
		||||
    if getattr(model.config, "model_type", None) == "chatglm":
 | 
			
		||||
        setattr(model, "lm_head", model.transformer.output_layer)
 | 
			
		||||
        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user