[model] fix dsv3 leaf node (#7879)

This commit is contained in:
hoshi-hiyouga 2025-04-28 18:11:09 +08:00 committed by GitHub
parent 00b5c05946
commit 1f338deb87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 7 deletions

View File

@ -1671,7 +1671,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if num_video_tokens >= len(videos):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
video_seqlen = (
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
)
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
)

View File

@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.misc import check_version
if TYPE_CHECKING:
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
check_version("deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
@ -44,10 +44,13 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [DbrxFFN])
if model_type == "deepseek_v3":
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
if model_type == "deepseek_v2":
# deepseek v2 uses custom code
_set_z3_leaf_modules(model, ["DeepseekV2MoE"])
_set_z3_leaf_modules(model, [DeepseekV3MoE])
if model_type == "deepseek_v3" or model_type == "kimi_vl":
# deepseek v3 and kimi vl use custom code
_set_z3_leaf_modules(model, ["DeepseekV3MoE"])
if model_type == "granitemoe":
from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE