mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
[data] fix qwen2audio plugin (#7166)
* Update pairwise.py [data]Repair multimodal model dpo training * Update pairwise.py [data]repair multimodal model dpo training using deepcopy * Update pairwise.py * Update mm_plugin.py Former-commit-id: dff4130969bac9cb1abe66fd5dfada8c757c716f
This commit is contained in:
parent
e1d574a784
commit
8dddffa340
@ -993,6 +993,7 @@ class Qwen2AudioPlugin(BasePlugin):
|
|||||||
self._validate_input(processor, images, videos, audios)
|
self._validate_input(processor, images, videos, audios)
|
||||||
bos_token: str = getattr(processor, "audio_bos_token")
|
bos_token: str = getattr(processor, "audio_bos_token")
|
||||||
eos_token: str = getattr(processor, "audio_eos_token")
|
eos_token: str = getattr(processor, "audio_eos_token")
|
||||||
|
messages = deepcopy(messages)
|
||||||
mm_inputs = self._get_mm_inputs([], [], audios, processor)
|
mm_inputs = self._get_mm_inputs([], [], audios, processor)
|
||||||
if "feature_attention_mask" in mm_inputs:
|
if "feature_attention_mask" in mm_inputs:
|
||||||
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
|
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user