Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-24 23:02:49 +08:00)
fix llava config
Former-commit-id: b033232aeaa1890ec6946387608aad4779a7ba10
parent ab94060839
commit 88159688bb
@@ -46,6 +46,9 @@ def init_adapter(
         if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             model = model.float()
 
+        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
+            model.vision_tower.requires_grad_(False)
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (
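Note: the branch added in this hunk freezes the vision tower whenever the model has one and visual inputs are enabled, so its weights receive no gradient updates. A minimal sketch of the same pattern on a toy module (ToyVLM and its submodule names are hypothetical stand-ins, not LLaMA-Factory code):

import torch.nn as nn

class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(16, 16)    # stands in for the vision encoder
        self.language_model = nn.Linear(16, 16)  # stands in for the LLM backbone

model = ToyVLM()
model.vision_tower.requires_grad_(False)  # same call the hunk adds

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['language_model.weight', 'language_model.bias']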
@@ -106,7 +106,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 
     model = None
     lazy_load = False
@@ -15,8 +15,8 @@ from .utils.longlora import configure_longlora
 from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
-from .utils.valuehead import configure_valuehead, prepare_valuehead_model
-from .utils.visual import autocast_projector_dtype
+from .utils.valuehead import prepare_valuehead_model
+from .utils.visual import autocast_projector_dtype, configure_hidden_size
 
 
 if TYPE_CHECKING:
@@ -40,7 +40,6 @@ def patch_config(
     model_args: "ModelArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
-    add_valuehead: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
@@ -50,9 +49,7 @@ def patch_config(
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
-
-    if add_valuehead:
-        configure_valuehead(config)
+    configure_hidden_size(config)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
@@ -8,7 +8,7 @@ from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel
+    from transformers import PreTrainedModel
 
     from ...hparams import ModelArguments
 
@@ -16,11 +16,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_valuehead(config: "PretrainedConfig") -> None:
-    if getattr(config, "model_type", None) == "llava":
-        setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
-
-
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
@@ -6,7 +6,7 @@ from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import PretrainedConfig, PreTrainedModel
 
     from ...hparams import ModelArguments
 
@@ -14,6 +14,11 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def configure_hidden_size(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "llava":
+        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
+
+
 def autocast_projector_dtype(
     model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
 ) -> None:
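Note: the new configure_hidden_size copies the language model's hidden size onto the top-level LLaVA config, replacing the removed configure_valuehead, which had read intermediate_size from the vision config; components sized from config.hidden_size, such as a value head, then match the language model's hidden states. A small sketch of the behaviour, assuming a transformers release (4.36 or later) where LlavaConfig exposes text_config:

from transformers import LlavaConfig

config = LlavaConfig()  # defaults fill in a CLIP-style vision_config and a LLaMA-style text_config

# Same logic as configure_hidden_size in the hunk above.
if getattr(config, "model_type", None) == "llava":
    setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

print(config.hidden_size == config.text_config.hidden_size)  # True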
@@ -22,7 +27,7 @@ def autocast_projector_dtype(
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
 
-    if hasattr(model, mm_projector_name):
+    if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None):
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
         mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
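Note: the changed condition registers the dtype-casting hook only when the config reports a quantization method, leaving unquantized models untouched. The hook itself follows the standard PyTorch forward-hook pattern; a minimal sketch with an illustrative stand-in projector:

import torch

projector = torch.nn.Linear(4, 4)  # stand-in for the multimodal projector

def cast_output(module, inputs, output):
    # Returning a tensor from a forward hook replaces the module's output.
    return output.to(torch.bfloat16)

projector.register_forward_hook(cast_output)
out = projector(torch.randn(2, 4))
print(out.dtype)  # torch.bfloat16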