[model] support youtu-vl model (#10152)
@@ -57,6 +57,11 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
                 "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
             )
 
+    if getattr(config, "model_type", None) in ["youtu", "youtu_vl"]:
+        if model_args.flash_attn in (AttentionFunction.AUTO, AttentionFunction.SDPA):
+            logger.warning_rank0("Youtu-VL does not support SDPA, forcing eager attention.")
+            model_args.flash_attn = AttentionFunction.DISABLED
+
     if model_args.flash_attn == AttentionFunction.AUTO:
         return
 
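This first hunk forces Youtu models off SDPA before the generic AUTO handling runs. A minimal sketch of the resulting fallback logic, assuming a simplified stand-in for the project's AttentionFunction enum (the member names appear in the diff; the string values and the helper resolve_flash_attn are illustrative):

from enum import Enum, unique

@unique
class AttentionFunction(str, Enum):  # simplified stand-in for the real enum
    AUTO = "auto"
    DISABLED = "disabled"
    SDPA = "sdpa"
    FA2 = "fa2"

def resolve_flash_attn(model_type: str, flash_attn: AttentionFunction) -> AttentionFunction:
    # Youtu/Youtu-VL cannot run SDPA, so both an explicit SDPA request and
    # AUTO (which may resolve to SDPA) are forced down to eager attention.
    if model_type in ("youtu", "youtu_vl") and flash_attn in (AttentionFunction.AUTO, AttentionFunction.SDPA):
        return AttentionFunction.DISABLED
    return flash_attn

assert resolve_flash_attn("youtu_vl", AttentionFunction.SDPA) is AttentionFunction.DISABLED
assert resolve_flash_attn("youtu_vl", AttentionFunction.FA2) is AttentionFunction.FA2  # FA2 passes through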
@@ -85,6 +90,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     elif getattr(config, "model_type", None) == "kimi_vl":
         setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
         setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
+    elif getattr(config, "model_type", None) == "youtu_vl":
+        setattr(config, "attn_implementation", requested_attn_implementation)
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "vision_config"):
+            setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
     else:
         setattr(config, "_attn_implementation", requested_attn_implementation)
 
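The second hunk mirrors the requested attention implementation onto the top-level config (both the public and the private attribute) and onto any vision/text sub-configs, since composite VL models resolve attention per sub-config. A minimal, self-contained sketch of that propagation, using SimpleNamespace stand-ins for the real PretrainedConfig objects (the helper name propagate_attn_implementation is illustrative):

from types import SimpleNamespace

def propagate_attn_implementation(config, impl: str) -> None:
    # Mirror the flag onto every level the model may read it from:
    # the top-level config (public and private attribute) and the
    # vision/text sub-configs, when present.
    setattr(config, "attn_implementation", impl)
    setattr(config, "_attn_implementation", impl)
    for sub_name in ("vision_config", "text_config"):
        if hasattr(config, sub_name):
            setattr(getattr(config, sub_name), "_attn_implementation", impl)

config = SimpleNamespace(vision_config=SimpleNamespace(), text_config=SimpleNamespace())
propagate_attn_implementation(config, "eager")
assert config._attn_implementation == "eager"
assert config.vision_config._attn_implementation == "eager"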
@@ -61,6 +61,26 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
     modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
 
 
+def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
+    original_forward = model.forward
+
+    def forward(self, *args, **kwargs):
+        outputs = original_forward(*args, **kwargs)
+        if "loss" not in outputs and "labels" in kwargs:
+            logits = outputs.get("logits")
+            labels = kwargs.get("labels")
+            if logits is not None and labels is not None:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                loss_fct = torch.nn.CrossEntropyLoss()
+                loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+                outputs["loss"] = loss
+
+        return outputs
+
+    model.forward = MethodType(forward, model)
+
+
 def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
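The wrapped forward computes a standard next-token loss when the upstream model returns none: logits at position t are scored against the label at position t + 1, and positions labeled -100 (prompt or padding) are skipped by CrossEntropyLoss's default ignore_index. A self-contained sketch of the shift with toy shapes:

import torch

batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))
labels[:, :2] = -100  # prompt/padding positions are ignored by the loss

# Drop the last logit and the first label so position t predicts token t + 1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)  # scalar tensor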
@@ -207,6 +227,9 @@ def patch_model(
     if getattr(model.config, "model_type", None) == "gemma3n":
        setattr(model_args, "disable_gradient_checkpointing", True)
 
+    if getattr(model.config, "model_type", None) == "youtu_vl":
+        patch_youtu_vl_model(model)
+
     prepare_model_for_training(model, model_args)
     autocast_projector_dtype(model, model_args)
     add_z3_leaf_module(model)
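One detail worth noting about patch_youtu_vl_model, which the hunk above wires into patch_model: original_forward is captured as a bound method, so the wrapper calls it without passing self, while MethodType binds the wrapper itself back onto the instance. A minimal sketch of that pattern with a toy class (the Model class here is illustrative):

from types import MethodType

class Model:
    def forward(self, x):
        return x + 1

model = Model()
original_forward = model.forward  # bound method: `model` is already baked in

def forward(self, x):
    # Called without `self` because `original_forward` is bound;
    # MethodType below re-binds this wrapper as the instance's forward.
    return original_forward(x) * 2

model.forward = MethodType(forward, model)
assert model.forward(1) == 4  # (1 + 1) * 2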