Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 03:40:34 +08:00)
fix llava config
@@ -15,8 +15,8 @@ from .utils.longlora import configure_longlora
 from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
-from .utils.valuehead import configure_valuehead, prepare_valuehead_model
-from .utils.visual import autocast_projector_dtype
+from .utils.valuehead import prepare_valuehead_model
+from .utils.visual import autocast_projector_dtype, configure_hidden_size
 
 
 if TYPE_CHECKING:
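
The import change swaps configure_valuehead for configure_hidden_size from the .utils.visual module. A minimal sketch of what the newly imported helper is assumed to do (illustrative only, not the code from this commit): LLaVA-style configs keep the language-model sizes in text_config, so code that reads config.hidden_size needs the value copied to the top level.

from transformers import PretrainedConfig

def configure_hidden_size(config: "PretrainedConfig") -> None:
    # Illustrative sketch: copy the text backbone's hidden size onto the
    # top-level config for LLaVA-style multimodal configs.
    if getattr(config, "model_type", None) == "llava":
        setattr(config, "hidden_size", config.text_config.hidden_size)
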
@@ -40,7 +40,6 @@ def patch_config(
     model_args: "ModelArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
-    add_valuehead: bool,
 ) -> None:
     if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
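
The comment in the context lines above notes the compute-dtype priority bf16 > fp16 > fp32. As a hedged illustration (not LLaMA-Factory's infer_optim_dtype itself), such a helper can be sketched as:

import torch

def pick_compute_dtype(model_dtype):
    # Illustrative sketch of the bf16 > fp16 > fp32 priority.
    if model_dtype == torch.bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16  # keep bf16 when the hardware supports it
    if torch.cuda.is_available():
        return torch.float16  # otherwise prefer fp16 on GPU
    return torch.float32  # fall back to fp32 elsewhere
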
@@ -50,9 +49,7 @@ def patch_config(
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
-
-    if add_valuehead:
-        configure_valuehead(config)
+    configure_hidden_size(config)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
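
With the add_valuehead parameter removed, configure_hidden_size(config) now runs unconditionally instead of only when a value head is attached. A hypothetical illustration of why this matters for LLaVA (the snippet below is an assumption for demonstration, not code from this commit): components such as a value head size their projection from config.hidden_size, which LLaVA keeps nested in text_config.

import torch
from transformers import LlavaConfig

config = LlavaConfig()  # nested config: the language-model sizes live in config.text_config

# Mimic the assumed effect of configure_hidden_size: expose hidden_size at the top level.
if getattr(config, "hidden_size", None) is None:
    setattr(config, "hidden_size", config.text_config.hidden_size)

value_head = torch.nn.Linear(config.hidden_size, 1)  # sizes correctly once hidden_size is set
print(value_head)
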