[data] fix mm pluigin for qwen omni video training (#9388)

Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
魅影 2025-11-03 11:44:27 +08:00 committed by GitHub
parent 767b344fb4
commit 215580c77d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 3 deletions

View File

@ -68,6 +68,8 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.video_processing_utils import BaseVideoProcessor
class EncodedImage(TypedDict):
path: Optional[str]
@ -1482,6 +1484,7 @@ class Qwen2VLPlugin(BasePlugin):
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
@ -1499,7 +1502,7 @@ class Qwen2VLPlugin(BasePlugin):
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names:
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
@ -1818,6 +1821,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
@ -1836,7 +1840,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
mm_inputs["video_second_per_grid"] = torch.tensor(
[temporal_patch_size / fps for fps in video_dict["fps_per_video"]]

View File

@ -57,7 +57,7 @@ def launch():
if is_env_enabled("USE_MCA"):
# force use torchrun
os.environ["FORCE_TORCHRUN"] = "1"
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
# launch distributed training
nnodes = os.getenv("NNODES", "1")