Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

support qwen2vl vllm infer

Former-commit-id: 03ddd2555fb97488cd4daab11e8b672d36150c5a
parent 1fef702382
commit bbd432415d
@@ -24,7 +24,7 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import DataCollatorForLanguageModeling

-from llamafactory.data import get_dataset, get_template_and_fix_tokenizer, MultiModalDataCollatorForSeq2Seq
+from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_tokenizer
@@ -71,7 +71,9 @@ def calculate_lr(
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":
-        data_collator = MultiModalDataCollatorForSeq2Seq(template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+        data_collator = MultiModalDataCollatorForSeq2Seq(
+            template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
+        )
     else:
         raise NotImplementedError(f"Stage does not supported: {stage}.")

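For context, downstream of this collator cal_lr.py counts the label tokens that actually contribute to the loss (the counting rule here is assumed from the script's purpose). A toy illustration with invented values:

```python
# Toy illustration of counting loss-bearing label tokens after collation.
import torch

IGNORE_INDEX = -100  # matches llamafactory.extras.constants.IGNORE_INDEX
labels = torch.tensor([[-100, -100, 52, 53, 54], [-100, 61, 62, -100, -100]])
valid_tokens = (labels != IGNORE_INDEX).sum().item()  # masked prompt tokens don't count
print(valid_tokens)  # 5
```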
@@ -16,16 +16,25 @@ import json

 import fire
 from transformers import Seq2SeqTrainingArguments
-from vllm import LLM, SamplingParams
-from vllm.lora.request import LoRARequest

 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.extras.misc import get_device_count
+from llamafactory.extras.packages import is_pillow_available, is_vllm_available
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer


+if is_pillow_available():
+    from PIL import Image
+    from PIL.Image import Image as ImageObject
+
+
+if is_vllm_available():
+    from vllm import LLM, SamplingParams
+    from vllm.lora.request import LoRARequest
+
+
 def vllm_infer(
     model_name_or_path: str,
     adapter_name_or_path: str = None,
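The imports are now guarded so Pillow and vLLM stay optional dependencies. A minimal sketch of the guard pattern; the real helpers live in llamafactory.extras.packages and may be implemented differently:

```python
# Sketch of an availability check that avoids importing the package itself.
import importlib.util


def _is_package_available(name: str) -> bool:
    # find_spec returns None when the package cannot be located.
    return importlib.util.find_spec(name) is not None


def is_vllm_available() -> bool:
    return _is_package_available("vllm")
```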
@@ -64,15 +73,29 @@ def vllm_infer(
         )
     )

-    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
+    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    template = get_template_and_fix_tokenizer(tokenizer, data_args)
-    dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]
+    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
+    template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)

     inputs, prompts, labels = [], [], []
-    for sample in dataset:
-        inputs.append({"prompt_token_ids": sample["input_ids"]})
+    for sample in dataset_module["train_dataset"]:
+        if sample["images"]:
+            multi_modal_data = {"image": []}
+            for image in sample["images"]:
+                if not isinstance(image, (str, ImageObject)):
+                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
+
+                if isinstance(image, str):
+                    image = Image.open(image).convert("RGB")
+
+                multi_modal_data["image"].append(image)
+        else:
+            multi_modal_data = None
+
+        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
         prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
         labels.append(
             tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
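The loop above builds one entry per sample, pairing pre-tokenized prompt ids with decoded PIL images. A hedged sketch of how vLLM consumes such entries (model name, file name, and ids are placeholders; exact generate() signatures vary across vLLM versions):

```python
# Hedged sketch: feeding token-id prompts plus images to vLLM.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct", limit_mm_per_prompt={"image": 4, "video": 2})
image = Image.open("example.jpg").convert("RGB")  # hypothetical local file

# prompt_token_ids must come from the template encoding above: with
# expand_mm_tokens=False they carry one image pad token per image, which
# vLLM's own processor then expands to match the vision features.
prompt_token_ids = [...]  # placeholder, produced by the dataset pipeline
inputs = [{"prompt_token_ids": prompt_token_ids, "multi_modal_data": {"image": [image]}}]

results = llm.generate(inputs, SamplingParams(temperature=0.7, max_tokens=128))
print(results[0].outputs[0].text)
```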
@@ -100,6 +123,9 @@ def vllm_infer(
         "disable_log_stats": True,
         "enable_lora": model_args.adapter_name_or_path is not None,
     }
+    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)

@@ -19,7 +19,7 @@ from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
 from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -67,6 +67,7 @@ class VllmEngine(BaseEngine):
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+        self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
         self.generating_args = generating_args.to_dict()

         engine_args = {
@@ -83,6 +84,9 @@ class VllmEngine(BaseEngine):
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
+        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
+            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+
         if isinstance(model_args.vllm_config, dict):
             engine_args.update(model_args.vllm_config)

@@ -108,19 +112,21 @@ class VllmEngine(BaseEngine):
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
+        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
         if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                 messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-        if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
-            image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
-        else:
-            image_str = self.template.mm_plugin.image_token or ""
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

-        paired_messages = [
-            {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
-            for message in messages
-        ] + [{"role": "assistant", "content": ""}]
+        messages = self.template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
+        )
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
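The new mm_input_dict path defers all placeholder expansion to the template's mm plugin; the engine itself only guarantees that one placeholder per image or video exists in the messages. A self-contained sketch of that prepending rule (helper name is ours; the placeholder string mirrors llamafactory.extras.constants.IMAGE_PLACEHOLDER, assumed to default to "<image>"):

```python
# Sketch: prepend one placeholder per image when the user supplied none.
from typing import Dict, List

IMAGE_PLACEHOLDER = "<image>"  # assumed default value of the constant


def prepend_missing_placeholders(messages: List[Dict[str, str]], n_images: int) -> None:
    # If no message mentions an image yet, prepend one placeholder per image
    # to the first message so the mm plugin can expand them later.
    if n_images and not any(IMAGE_PLACEHOLDER in m["content"] for m in messages):
        messages[0]["content"] = IMAGE_PLACEHOLDER * n_images + messages[0]["content"]


msgs = [{"role": "user", "content": "Describe the photo."}]
prepend_missing_placeholders(msgs, 2)
print(msgs[0]["content"])  # "<image><image>Describe the photo."
```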
@@ -168,7 +174,7 @@ class VllmEngine(BaseEngine):
         )

         if images is not None:  # add image features
-            image_data = []
+            multi_modal_data = {"image": []}
             for image in images:
                 if not isinstance(image, (str, ImageObject)):
                     raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
@@ -176,9 +182,7 @@ class VllmEngine(BaseEngine):
                 if isinstance(image, str):
                     image = Image.open(image).convert("RGB")

-                image_data.append(image)
-
-            multi_modal_data = {"image": image_data}
+                multi_modal_data["image"].append(image)
         else:
             multi_modal_data = None

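Both this loop and the one in vllm_infer normalize mixed path/PIL inputs into RGB PIL images before handing them to vLLM. The same logic as a standalone helper (function name is ours):

```python
# Sketch: normalize file paths and PIL images into a uniform RGB image list.
from typing import List, Union

from PIL import Image
from PIL.Image import Image as ImageObject


def to_pil_images(images: List[Union[str, ImageObject]]) -> List[ImageObject]:
    loaded = []
    for image in images:
        if not isinstance(image, (str, ImageObject)):
            raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

        if isinstance(image, str):
            image = Image.open(image).convert("RGB")  # load from disk, force 3 channels

        loaded.append(image)
    return loaded
```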
@@ -62,6 +62,7 @@ class BasePlugin:
     def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
         self.image_token = image_token
         self.video_token = video_token
+        self.expand_mm_tokens = True

     def _validate_input(
         self,
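expand_mm_tokens is the switch this commit revolves around: during training each placeholder expands into the full per-image token run so sequence lengths match the vision features, while vLLM expands multimodal tokens internally and only needs a single marker. A toy illustration (seqlen value invented):

```python
# Toy illustration of what expand_mm_tokens controls.
IMAGE_SEQLEN = 576  # e.g. a LLaVA-style processor.image_seqlen


def render(content: str, expand_mm_tokens: bool) -> str:
    image_seqlen = IMAGE_SEQLEN if expand_mm_tokens else 1
    return content.replace("<image>", "{{image}}" * image_seqlen, 1)


print(render("<image>What is this?", False))                    # "{{image}}What is this?"
print(render("<image>What is this?", True).count("{{image}}"))  # 576
```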
@@ -259,7 +260,7 @@ class LlavaPlugin(BasePlugin):
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
-        image_seqlen = getattr(processor, "image_seqlen")
+        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
@@ -310,11 +311,13 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_size = next(image_sizes)
-                orig_height, orig_width = image_size
-                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
-                    image_seqlen -= 1
+                if self.expand_mm_tokens:
+                    orig_height, orig_width = next(image_sizes)
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
+                        image_seqlen -= 1
+                else:
+                    image_seqlen = 1

                 num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -359,11 +362,13 @@ class LlavaNextVideoPlugin(BasePlugin):
             for message in messages:
                 content = message["content"]
                 while IMAGE_PLACEHOLDER in content:
-                    image_size = next(image_sizes)
-                    orig_height, orig_width = image_size
-                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if getattr(processor, "vision_feature_select_strategy") == "default":
-                        image_seqlen -= 1
+                    if self.expand_mm_tokens:
+                        orig_height, orig_width = next(image_sizes)
+                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                        if getattr(processor, "vision_feature_select_strategy") == "default":
+                            image_seqlen -= 1
+                    else:
+                        image_seqlen = 1

                     num_image_tokens += 1
                     content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -376,6 +381,7 @@ class LlavaNextVideoPlugin(BasePlugin):
             num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
             video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            video_seqlen = video_seqlen if self.expand_mm_tokens else 1
             for message in messages:
                 content = message["content"]
                 while VIDEO_PLACEHOLDER in content:
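A worked example of the sequence-length arithmetic above, with illustrative processor values:

```python
# Worked example of LlavaNextVideo token counts (values invented).
height = width = 336
patch_size = 14
num_frames = 8

image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
video_seqlen = image_seqlen // 4 * num_frames  # avg pooling shrinks 4x: 144 * 8 = 1152
print(image_seqlen, video_seqlen)  # 576 1152
```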
@@ -443,7 +449,7 @@ class PaliGemmaPlugin(BasePlugin):
     ) -> Tuple[List[int], Optional[List[int]]]:
         self._validate_input(images, videos)
         num_images = len(images)
-        image_seqlen = num_images * getattr(processor, "image_seqlen")
+        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
         input_ids = [image_token_id] * image_seqlen + input_ids
         if labels is not None:
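A toy illustration (all ids invented) of the PaliGemma prefix logic: with expansion on, image tokens are prepended to input_ids and, presumably, masked out of the labels; with it off, image_seqlen is 0 and nothing is prepended:

```python
# Toy illustration of the PaliGemma image-token prefix (ids are hypothetical).
IGNORE_INDEX = -100
image_token_id = 257152     # hypothetical id of the image token
input_ids = [2, 106, 107]   # hypothetical text token ids
labels = [2, 106, 107]

num_images, processor_image_seqlen, expand_mm_tokens = 1, 4, True
image_seqlen = num_images * processor_image_seqlen if expand_mm_tokens else 0

input_ids = [image_token_id] * image_seqlen + input_ids
labels = [IGNORE_INDEX] * image_seqlen + labels  # image prefix carries no loss
print(input_ids)  # [257152, 257152, 257152, 257152, 2, 106, 107]
```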
@@ -493,14 +499,18 @@ class PixtralPlugin(BasePlugin):
                 if image_input_sizes is None:
                     raise ValueError("Cannot get image input sizes.")

-                image_size = image_input_sizes[0][num_image_tokens]
-                height, width = image_size
-                num_height_tokens = height // patch_size
-                num_width_tokens = width // patch_size
-                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
-                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
-                replace_tokens[-1] = image_end_token
-                replace_str = "".join(replace_tokens)
+                if self.expand_mm_tokens:
+                    image_size = image_input_sizes[0][num_image_tokens]
+                    height, width = image_size
+                    num_height_tokens = height // patch_size
+                    num_width_tokens = width // patch_size
+                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
+                    replace_tokens[-1] = image_end_token
+                    replace_str = "".join(replace_tokens)
+                else:
+                    replace_str = image_token

                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                 num_image_tokens += 1
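A worked example of the Pixtral row layout built above (token strings shown are illustrative): each row of width tokens ends in a break token, and the final break becomes the end marker:

```python
# Worked example of the Pixtral image-token grid (values illustrative).
image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"
height, width, patch_size = 32, 48, 16

num_height_tokens = height // patch_size  # 2 rows
num_width_tokens = width // patch_size    # 3 tokens per row
rows = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
tokens = [tok for row in rows for tok in row]  # flatten the row lists
tokens[-1] = image_end_token  # the last break becomes the end marker
print("".join(tokens))
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]
```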
@@ -549,10 +559,27 @@ class Qwen2vlPlugin(BasePlugin):
         return image

     @override
-    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
-        sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
-        sample_frames = sample_frames // 2 * 2
-        return sample_frames
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        results = []
+        for video in videos:
+            container = av.open(video, "r")
+            video_stream = next(stream for stream in container.streams if stream.type == "video")
+            total_frames = video_stream.frames
+            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
+            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            frames: List["ImageObject"] = []
+            container.seek(0)
+            for frame_idx, frame in enumerate(container.decode(video_stream)):
+                if frame_idx in sample_indices:
+                    frames.append(frame.to_image())
+
+            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
+                frames.append(frames[-1])
+
+            frames = self._regularize_images(frames, **kwargs)
+            results.append(frames)
+
+        return results

     @override
     def process_messages(
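The new _regularize_videos override samples evenly spaced frames and pads to an even count, since qwen2-vl processes frames in temporal pairs. A minimal sketch of just the sampling rule, with invented frame counts:

```python
# Sketch of the frame-sampling rule: evenly spaced indices, padded to even.
import numpy as np

total_frames, sample_frames = 45, 7
indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
print(indices.tolist())  # [0, 7, 14, 22, 29, 36, 44]

frames = list(indices)
if len(frames) % 2 != 0:  # duplicate the last frame to reach an even count
    frames.append(frames[-1])
print(len(frames))  # 8
```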
@@ -577,12 +604,9 @@ class Qwen2vlPlugin(BasePlugin):
                 if num_image_tokens >= len(image_grid_thw):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    IMAGE_PLACEHOLDER,
-                    "<|vision_start|>{}<|vision_end|>".format(
-                        self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
-                    ),
-                    1,
+                    IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
                 )
                 num_image_tokens += 1

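A worked example (grid values invented) of the qwen2-vl token count above: image_grid_thw holds the (temporal, height, width) patch grid, and merge_length is merge_size squared because the vision tower merges 2x2 patches:

```python
# Worked example of the qwen2-vl image token count (grid values invented).
import numpy as np

merge_length = 2 ** 2  # spatial merge of 2x2 patches
image_grid_thw = np.array([1, 46, 34])  # one frame on a 46x34 patch grid
image_seqlen = int(image_grid_thw.prod()) // merge_length  # 1564 // 4 = 391

# The prompt receives the run wrapped in vision markers; the pad-token
# string is assumed from the qwen2-vl tokenizer.
prompt_piece = "<|vision_start|>" + "<|image_pad|>" * image_seqlen + "<|vision_end|>"
print(image_seqlen)  # 391
```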
@@ -590,12 +614,9 @@ class Qwen2vlPlugin(BasePlugin):
                 if num_video_tokens >= len(video_grid_thw):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

+                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    VIDEO_PLACEHOLDER,
-                    "<|vision_start|>{}<|vision_end|>".format(
-                        self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
-                    ),
-                    1,
+                    VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
                 )
                 num_video_tokens += 1

@@ -640,19 +661,22 @@ class VideoLlavaPlugin(BasePlugin):
         has_images = "pixel_values_images" in mm_inputs
         has_videos = "pixel_values_videos" in mm_inputs
         if has_images or has_videos:
-            if has_images:
-                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
-                num_frames = 1
+            if self.expand_mm_tokens:
+                if has_images:
+                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                    num_frames = 1

-            if has_videos:
-                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(pixel_values_video[0])
-                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+                if has_videos:
+                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                    height, width = get_image_size(pixel_values_video[0])
+                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim

-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
-            video_seqlen = image_seqlen * num_frames
-            if getattr(processor, "vision_feature_select_strategy") == "default":
-                image_seqlen -= 1
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+                video_seqlen = image_seqlen * num_frames
+                if getattr(processor, "vision_feature_select_strategy") == "default":
+                    image_seqlen -= 1
+            else:
+                image_seqlen, video_seqlen = 1, 1

             for message in messages:
                 content = message["content"]