[model] support youtu-vl model (#10152)

This commit is contained in:
Hertz
2026-02-02 21:42:43 +08:00
committed by GitHub
parent bf04ca6af8
commit b53d7037c2
5 changed files with 95 additions and 0 deletions

View File

@@ -61,6 +61,26 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
original_forward = model.forward
def forward(self, *args, **kwargs):
outputs = original_forward(*args, **kwargs)
if "loss" not in outputs and "labels" in kwargs:
logits = outputs.get("logits")
labels = kwargs.get("labels")
if logits is not None and labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
outputs["loss"] = loss
return outputs
model.forward = MethodType(forward, model)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@@ -207,6 +227,9 @@ def patch_model(
if getattr(model.config, "model_type", None) == "gemma3n":
setattr(model_args, "disable_gradient_checkpointing", True)
if getattr(model.config, "model_type", None) == "youtu_vl":
patch_youtu_vl_model(model)
prepare_model_for_training(model, model_args)
autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model)