Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)

[model] fix dsv3 leaf node (#7879)

This commit is contained in:
parent 00b5c05946
commit 1f338deb87
@@ -1671,7 +1671,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 if num_video_tokens >= len(videos):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
-                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                video_seqlen = (
+                    video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                )
                 content = content.replace(
                     VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
                 )
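For context, the hunk above only rewraps a long line: `video_seqlen` is the number of tokens a single video placeholder expands to, computed as the product of the video's (temporal, height, width) patch grid divided by the merge factor. Below is a minimal standalone sketch of that arithmetic and the placeholder replacement; the grid size, merge factor, and token strings are made-up values for illustration, not taken from the repository.

# Illustrative sketch of the expansion performed in the hunk above.
# All concrete values (grid size, merge_size, token strings) are made up.
VIDEO_PLACEHOLDER = "<video>"
video_token = "<|video_pad|>"

video_grid_thw = (2, 12, 16)   # one video: (temporal, height, width) patches
merge_size = 2
merge_length = merge_size**2   # patches merged into one visual token

# same arithmetic as the diff: total patches divided by the merge factor
video_seqlen = (video_grid_thw[0] * video_grid_thw[1] * video_grid_thw[2]) // merge_length

content = f"Describe this clip: {VIDEO_PLACEHOLDER}"
content = content.replace(
    VIDEO_PLACEHOLDER, f"<|vision_bos|>{video_token * video_seqlen}<|vision_eos|>", 1
)
print(video_seqlen)                 # 96
print(content.count(video_token))   # 96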
@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 
-import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ...extras.misc import check_version
 
 
 if TYPE_CHECKING:
+    from torch import nn
     from transformers import PretrainedConfig, PreTrainedModel
 
     from ...hparams import ModelArguments
 
 
-def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
     check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore
 
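The widened `leaf_modules` annotation is the core of this hunk: for models shipped as custom/remote code, the MoE module class cannot be imported from `transformers` ahead of time, so callers now pass the class name as a string. Below is a minimal sketch of one way such strings could be resolved against an instantiated model before real classes are handed to DeepSpeed; the helper name and the resolution strategy are assumptions for illustration, not the repository's or DeepSpeed's implementation.

from typing import Union

from torch import nn


def resolve_leaf_classes(model: nn.Module, leaf_modules: list[Union[type, str]]) -> list[type]:
    # Hypothetical helper: match name strings against the classes of the
    # modules actually present in the model; pass classes through unchanged.
    resolved: list[type] = []
    for entry in leaf_modules:
        if isinstance(entry, str):
            resolved.extend({type(m) for m in model.modules() if type(m).__name__ == entry})
        else:
            resolved.append(entry)
    return resolved

Classes resolved this way can then be passed to `set_z3_leaf_modules(model, resolved)` just like the class-based branches in the next hunk.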
@@ -44,10 +44,13 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [DbrxFFN])
 
-    if model_type == "deepseek_v3":
-        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+    if model_type == "deepseek_v2":
+        # deepseek v2 uses custom code
+        _set_z3_leaf_modules(model, ["DeepseekV2MoE"])
 
-        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+    if model_type == "deepseek_v3" or model_type == "kimi_vl":
+        # deepseek v3 and kimi vl use custom code
+        _set_z3_leaf_modules(model, ["DeepseekV3MoE"])
 
     if model_type == "granitemoe":
         from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
 
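These branches run inside `add_z3_leaf_module`, which marks whole MoE blocks as ZeRO-3 leaf modules so DeepSpeed gathers each block's parameters as one unit instead of hooking every expert, avoiding stalls when a step activates only some experts. A hedged end-to-end sketch of the same idea outside the repository follows; the model name, the `trust_remote_code` usage, and the inline name-based resolution (mirroring the hypothetical helper sketched earlier) are illustrative assumptions, not the repository's code path.

# Hedged usage sketch: register a custom-code model's MoE blocks as ZeRO-3
# leaves so DeepSpeed gathers each block's parameters as one unit.
# The model name below is illustrative only.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-V3", trust_remote_code=True
)

if is_deepspeed_zero3_enabled():
    # resolve the class by name, since DeepseekV3MoE may only exist as
    # custom code after the model has been instantiated
    leaf_classes = list({type(m) for m in model.modules() if type(m).__name__ == "DeepseekV3MoE"})
    set_z3_leaf_modules(model, leaf_classes)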