Merge branch 'main' into cpei/refactor

Former-commit-id: 2c6262c3cd
This commit is contained in:
hoshi-hiyouga
2024-10-08 17:31:17 +08:00
committed by GitHub
19 changed files with 775 additions and 240 deletions

View File

@@ -21,12 +21,12 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING:
@@ -61,7 +61,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer.
Loads pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args.
"""
@@ -96,15 +96,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
patch_processor(processor, config, tokenizer, model_args)
except Exception as e:
logger.warning("Failed to load processor. Error: {}".format(e))
processor = None
@@ -138,6 +132,7 @@ def load_model(
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None
lazy_load = False
@@ -158,7 +153,6 @@ def load_model(
load_class = AutoModelForVision2Seq
else:
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config)
else:

View File

@@ -37,10 +37,11 @@ def configure_attn_implementation(
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
if model_args.flash_attn != "fa2":
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else:
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
logger.warning("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa":
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
@@ -26,7 +27,12 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
def apply_liger_kernel(
config: "PretrainedConfig",
model_args: "ModelArguments",
is_trainable: bool,
require_logits: bool,
) -> None:
if not is_trainable or not model_args.enable_liger_kernel:
return
@@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
logger.warning("Current model does not support liger kernel.")
return
apply_liger_kernel()
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info("Liger kernel has been applied to the model.")

View File

@@ -34,7 +34,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output_layer")
elif model_type == "internlm2":
forbidden_modules.add("output")
elif model_type in ["llava", "paligemma"]:
elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl":
forbidden_modules.add("merger")

View File

@@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None)
if model_type in ["llava", "paligemma"]:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
@@ -108,7 +108,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
Patches VLMs before loading them.
"""
model_type = getattr(config, "model_type", None)
if model_type == "llava": # required for ds zero3 and valuehead models
if model_type in [
"llava",
"llava_next",
"llava_next_video",
"paligemma",
"video_llava",
]: # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
@@ -122,7 +128,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
"""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in ["llava", "paligemma"]:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
@@ -150,12 +156,28 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
image_seqlen += 1
elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
elif model_type == "qwen2_vl": # variable length
else:
image_seqlen = -1
return image_seqlen
def get_patch_size(config: "PretrainedConfig") -> int:
r"""
Computes the patch size of the vit.
"""
patch_size = getattr(config.vision_config, "patch_size", -1)
return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
r"""
Get the vision_feature_select_strategy.
"""
vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
return vision_feature_select_strategy
def patch_target_modules(
config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> Union[str, List[str]]:
@@ -164,7 +186,7 @@ def patch_target_modules(
"""
model_type = getattr(config, "model_type", None)
if finetuning_args.freeze_vision_tower:
if model_type in ["llava", "paligemma"]:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))

View File

@@ -27,18 +27,23 @@ from ..extras.misc import infer_optim_dtype
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer
from .model_utils.liger_kernel import configure_liger_kernel
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import autocast_projector_dtype, configure_visual_model
from .model_utils.visual import (
autocast_projector_dtype,
configure_visual_model,
get_image_seqlen,
get_patch_size,
get_vision_feature_select_strategy,
)
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
@@ -52,6 +57,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
def patch_processor(
processor: "ProcessorMixin",
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
) -> None:
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "patch_size", get_patch_size(config))
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
@@ -71,7 +92,6 @@ def patch_config(
configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable)
configure_liger_kernel(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
@@ -90,6 +110,9 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())