mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
@@ -21,12 +21,12 @@ from trl import AutoModelForCausalLMWithValueHead
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
from .model_utils.misc import register_autoclass
|
||||
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||
from .model_utils.unsloth import load_unsloth_pretrained_model
|
||||
from .model_utils.valuehead import load_valuehead_params
|
||||
from .model_utils.visual import get_image_seqlen
|
||||
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
|
||||
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -61,7 +61,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
r"""
|
||||
Loads pretrained tokenizer.
|
||||
Loads pretrained tokenizer and optionally loads processor.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
@@ -96,15 +96,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
|
||||
|
||||
patch_tokenizer(tokenizer)
|
||||
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||
setattr(processor, "image_resolution", model_args.image_resolution)
|
||||
setattr(processor, "video_resolution", model_args.video_resolution)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
setattr(processor, "video_maxlen", model_args.video_maxlen)
|
||||
patch_processor(processor, config, tokenizer, model_args)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load processor. Error: {}".format(e))
|
||||
processor = None
|
||||
@@ -138,6 +132,7 @@ def load_model(
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
config = load_config(model_args)
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
|
||||
|
||||
model = None
|
||||
lazy_load = False
|
||||
@@ -158,7 +153,6 @@ def load_model(
|
||||
load_class = AutoModelForVision2Seq
|
||||
else:
|
||||
load_class = AutoModelForCausalLM
|
||||
|
||||
if model_args.train_from_scratch:
|
||||
model = load_class.from_config(config)
|
||||
else:
|
||||
|
||||
@@ -37,10 +37,11 @@ def configure_attn_implementation(
|
||||
if is_flash_attn_2_available():
|
||||
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
|
||||
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
|
||||
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
|
||||
model_args.flash_attn = "fa2"
|
||||
if model_args.flash_attn != "fa2":
|
||||
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
|
||||
model_args.flash_attn = "fa2"
|
||||
else:
|
||||
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
|
||||
logger.warning("FlashAttention-2 is not installed, use eager attention.")
|
||||
model_args.flash_attn = "disabled"
|
||||
elif model_args.flash_attn == "sdpa":
|
||||
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
@@ -26,7 +27,12 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
def apply_liger_kernel(
|
||||
config: "PretrainedConfig",
|
||||
model_args: "ModelArguments",
|
||||
is_trainable: bool,
|
||||
require_logits: bool,
|
||||
) -> None:
|
||||
if not is_trainable or not model_args.enable_liger_kernel:
|
||||
return
|
||||
|
||||
@@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
|
||||
logger.warning("Current model does not support liger kernel.")
|
||||
return
|
||||
|
||||
apply_liger_kernel()
|
||||
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
|
||||
logger.info("Current training stage does not support chunked cross entropy.")
|
||||
kwargs = {"fused_linear_cross_entropy": False}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
apply_liger_kernel(**kwargs)
|
||||
logger.info("Liger kernel has been applied to the model.")
|
||||
|
||||
@@ -34,7 +34,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
|
||||
forbidden_modules.add("output_layer")
|
||||
elif model_type == "internlm2":
|
||||
forbidden_modules.add("output")
|
||||
elif model_type in ["llava", "paligemma"]:
|
||||
elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
|
||||
forbidden_modules.add("multi_modal_projector")
|
||||
elif model_type == "qwen2_vl":
|
||||
forbidden_modules.add("merger")
|
||||
|
||||
@@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
|
||||
|
||||
if getattr(model, "quantization_method", None):
|
||||
model_type = getattr(model.config, "model_type", None)
|
||||
if model_type in ["llava", "paligemma"]:
|
||||
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
|
||||
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
|
||||
elif model_type == "qwen2_vl":
|
||||
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
|
||||
@@ -108,7 +108,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
|
||||
Patches VLMs before loading them.
|
||||
"""
|
||||
model_type = getattr(config, "model_type", None)
|
||||
if model_type == "llava": # required for ds zero3 and valuehead models
|
||||
if model_type in [
|
||||
"llava",
|
||||
"llava_next",
|
||||
"llava_next_video",
|
||||
"paligemma",
|
||||
"video_llava",
|
||||
]: # required for ds zero3 and valuehead models
|
||||
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
|
||||
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
@@ -122,7 +128,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
|
||||
"""
|
||||
model_type = getattr(config, "model_type", None)
|
||||
forbidden_modules = set()
|
||||
if model_type in ["llava", "paligemma"]:
|
||||
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
forbidden_modules.add("vision_tower")
|
||||
|
||||
@@ -150,12 +156,28 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
|
||||
image_seqlen += 1
|
||||
elif model_type == "paligemma":
|
||||
image_seqlen = config.vision_config.num_image_tokens
|
||||
elif model_type == "qwen2_vl": # variable length
|
||||
else:
|
||||
image_seqlen = -1
|
||||
|
||||
return image_seqlen
|
||||
|
||||
|
||||
def get_patch_size(config: "PretrainedConfig") -> int:
|
||||
r"""
|
||||
Computes the patch size of the vit.
|
||||
"""
|
||||
patch_size = getattr(config.vision_config, "patch_size", -1)
|
||||
return patch_size
|
||||
|
||||
|
||||
def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
|
||||
r"""
|
||||
Get the vision_feature_select_strategy.
|
||||
"""
|
||||
vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
|
||||
return vision_feature_select_strategy
|
||||
|
||||
|
||||
def patch_target_modules(
|
||||
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
|
||||
) -> Union[str, List[str]]:
|
||||
@@ -164,7 +186,7 @@ def patch_target_modules(
|
||||
"""
|
||||
model_type = getattr(config, "model_type", None)
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
if model_type in ["llava", "paligemma"]:
|
||||
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
|
||||
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
|
||||
elif model_type == "qwen2_vl":
|
||||
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
|
||||
|
||||
@@ -27,18 +27,23 @@ from ..extras.misc import infer_optim_dtype
|
||||
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
|
||||
from .model_utils.checkpointing import prepare_model_for_training
|
||||
from .model_utils.embedding import resize_embedding_layer
|
||||
from .model_utils.liger_kernel import configure_liger_kernel
|
||||
from .model_utils.longlora import configure_longlora
|
||||
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
||||
from .model_utils.packing import configure_packing
|
||||
from .model_utils.quantization import configure_quantization
|
||||
from .model_utils.rope import configure_rope
|
||||
from .model_utils.valuehead import prepare_valuehead_model
|
||||
from .model_utils.visual import autocast_projector_dtype, configure_visual_model
|
||||
from .model_utils.visual import (
|
||||
autocast_projector_dtype,
|
||||
configure_visual_model,
|
||||
get_image_seqlen,
|
||||
get_patch_size,
|
||||
get_vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..hparams import ModelArguments
|
||||
@@ -52,6 +57,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
|
||||
def patch_processor(
|
||||
processor: "ProcessorMixin",
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
) -> None:
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||
setattr(processor, "image_resolution", model_args.image_resolution)
|
||||
setattr(processor, "patch_size", get_patch_size(config))
|
||||
setattr(processor, "video_resolution", model_args.video_resolution)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
setattr(processor, "video_maxlen", model_args.video_maxlen)
|
||||
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
|
||||
|
||||
|
||||
def patch_config(
|
||||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
@@ -71,7 +92,6 @@ def patch_config(
|
||||
|
||||
configure_attn_implementation(config, model_args, is_trainable)
|
||||
configure_rope(config, model_args, is_trainable)
|
||||
configure_liger_kernel(config, model_args, is_trainable)
|
||||
configure_longlora(config, model_args, is_trainable)
|
||||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
configure_moe(config, model_args, is_trainable)
|
||||
@@ -90,6 +110,9 @@ def patch_config(
|
||||
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
|
||||
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
|
||||
|
||||
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
||||
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||
|
||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user