mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
[model] fix dsv3 leaf node (#7879)
This commit is contained in:
parent
00b5c05946
commit
1f338deb87
@ -1671,7 +1671,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
if num_video_tokens >= len(videos):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
video_seqlen = (
|
||||
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
)
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
|
||||
)
|
||||
|
@ -12,21 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...extras.misc import check_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
|
||||
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list[Union["nn.Module", str]]) -> None:
|
||||
check_version("deepspeed>=0.13.0")
|
||||
from deepspeed.utils import set_z3_leaf_modules # type: ignore
|
||||
|
||||
@ -44,10 +44,13 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
|
||||
_set_z3_leaf_modules(model, [DbrxFFN])
|
||||
|
||||
if model_type == "deepseek_v3":
|
||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
if model_type == "deepseek_v2":
|
||||
# deepseek v2 uses custom code
|
||||
_set_z3_leaf_modules(model, ["DeepseekV2MoE"])
|
||||
|
||||
_set_z3_leaf_modules(model, [DeepseekV3MoE])
|
||||
if model_type == "deepseek_v3" or model_type == "kimi_vl":
|
||||
# deepseek v3 and kimi vl use custom code
|
||||
_set_z3_leaf_modules(model, ["DeepseekV3MoE"])
|
||||
|
||||
if model_type == "granitemoe":
|
||||
from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeMoE
|
||||
|
Loading…
x
Reference in New Issue
Block a user