Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-11-04 01:42:14 +08:00
[data] fix mm plugin for qwen omni video training (#9388)
Co-authored-by: frozenleaves <frozen@Mac.local>
parent 767b344fb4
commit 215580c77d
```diff
@@ -68,6 +68,8 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
     from transformers.image_processing_utils import BaseImageProcessor
+    from transformers.video_processing_utils import BaseVideoProcessor
+
 
     class EncodedImage(TypedDict):
         path: Optional[str]
```
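The new import is needed because recent transformers releases split video preprocessing into a dedicated `BaseVideoProcessor` class. A minimal sketch of the lookup pattern the plugin relies on (the checkpoint name is only an example; `getattr` with a `None` default keeps older processors that lack a `video_processor` attribute working):

```python
from transformers import AutoProcessor

# Example checkpoint; any Qwen2-VL / Qwen2.5-Omni processor behaves the same way.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

# Newer transformers expose a dedicated video processor next to the image processor.
image_processor = getattr(processor, "image_processor", None)
video_processor = getattr(processor, "video_processor", None)
print(type(image_processor).__name__, type(video_processor).__name__)
```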
```diff
@@ -1482,6 +1484,7 @@ class Qwen2VLPlugin(BasePlugin):
         processor: "MMProcessor",
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
@@ -1499,7 +1502,7 @@ class Qwen2VLPlugin(BasePlugin):
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
+            mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
             temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
             if "second_per_grid_ts" in processor.model_input_names:
                 mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
```
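The substantive change: videos are now preprocessed by the dedicated `video_processor` instead of being routed through `image_processor` with `images=None`. The `second_per_grid_ts` value is just the number of seconds of video covered by one temporal grid step; a worked sketch of the arithmetic with illustrative values:

```python
# temporal_patch_size frames are merged into one temporal grid step,
# so each step spans temporal_patch_size / fps seconds of video.
temporal_patch_size = 2        # plugin default, per the getattr fallback above
fps_per_video = [2.0, 4.0]     # sampled fps of two hypothetical videos

second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
print(second_per_grid_ts)      # [1.0, 0.5]
```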
```diff
@@ -1818,6 +1821,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         processor: "MMProcessor",
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
         feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
         if len(images) != 0:
@@ -1836,7 +1840,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
+            mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
             temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
             mm_inputs["video_second_per_grid"] = torch.tensor(
                 [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
```
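`Qwen2OmniPlugin` gets the same fix, with one difference in the output format: Qwen2-VL consumes `second_per_grid_ts` as a plain list, while the Omni model consumes `video_second_per_grid` as a tensor. A small sketch of that distinction:

```python
import torch

temporal_patch_size = 2
fps_per_video = [2.0, 4.0]

# Qwen2-VL: plain Python list.
second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]

# Qwen2.5-Omni: the same values wrapped in a float tensor, as in the hunk above.
video_second_per_grid = torch.tensor(second_per_grid_ts)
print(second_per_grid_ts, video_second_per_grid)  # [1.0, 0.5] tensor([1.0000, 0.5000])
```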
```diff
@@ -57,7 +57,7 @@ def launch():
     if is_env_enabled("USE_MCA"):
         # force use torchrun
         os.environ["FORCE_TORCHRUN"] = "1"

     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
```
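For context, the launcher forces torchrun when `USE_MCA` is set by exporting `FORCE_TORCHRUN=1`, and the training branch then triggers on that flag or on multi-device runs without Ray. A minimal sketch of how such an env-flag check typically works; this `is_env_enabled` body is an assumption for illustration, not a copy of LLaMA-Factory's implementation:

```python
import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # Assumed helper: treat common truthy strings as "enabled".
    return os.getenv(env_var, default).strip().lower() in ("true", "y", "1")

os.environ["FORCE_TORCHRUN"] = "1"
print(is_env_enabled("FORCE_TORCHRUN"))  # True
print(is_env_enabled("USE_MCA"))         # False unless set in the environment
```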