mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
support qwen2vl vllm infer
Former-commit-id: 207f8b069ca35a28de4588b4962e7254f451c52c
This commit is contained in:
parent
7f8c59144e
commit
88b06a0c7f
@ -24,7 +24,7 @@ from torch.utils.data import DataLoader
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import DataCollatorForLanguageModeling
|
from transformers import DataCollatorForLanguageModeling
|
||||||
|
|
||||||
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer, MultiModalDataCollatorForSeq2Seq
|
from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
|
||||||
from llamafactory.extras.constants import IGNORE_INDEX
|
from llamafactory.extras.constants import IGNORE_INDEX
|
||||||
from llamafactory.hparams import get_train_args
|
from llamafactory.hparams import get_train_args
|
||||||
from llamafactory.model import load_tokenizer
|
from llamafactory.model import load_tokenizer
|
||||||
@ -71,7 +71,9 @@ def calculate_lr(
|
|||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
elif stage == "sft":
|
elif stage == "sft":
|
||||||
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||||
|
template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||||
|
|
||||||
|
@ -16,16 +16,25 @@ import json
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from vllm import LLM, SamplingParams
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
|
|
||||||
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
||||||
from llamafactory.extras.constants import IGNORE_INDEX
|
from llamafactory.extras.constants import IGNORE_INDEX
|
||||||
from llamafactory.extras.misc import get_device_count
|
from llamafactory.extras.misc import get_device_count
|
||||||
|
from llamafactory.extras.packages import is_pillow_available, is_vllm_available
|
||||||
from llamafactory.hparams import get_infer_args
|
from llamafactory.hparams import get_infer_args
|
||||||
from llamafactory.model import load_tokenizer
|
from llamafactory.model import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if is_pillow_available():
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import Image as ImageObject
|
||||||
|
|
||||||
|
|
||||||
|
if is_vllm_available():
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
|
||||||
def vllm_infer(
|
def vllm_infer(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
adapter_name_or_path: str = None,
|
adapter_name_or_path: str = None,
|
||||||
@ -64,15 +73,29 @@ def vllm_infer(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
|
training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||||
dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]
|
template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate
|
||||||
|
dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
|
||||||
|
|
||||||
inputs, prompts, labels = [], [], []
|
inputs, prompts, labels = [], [], []
|
||||||
for sample in dataset:
|
for sample in dataset_module["train_dataset"]:
|
||||||
inputs.append({"prompt_token_ids": sample["input_ids"]})
|
if sample["images"]:
|
||||||
|
multi_modal_data = {"image": []}
|
||||||
|
for image in sample["images"]:
|
||||||
|
if not isinstance(image, (str, ImageObject)):
|
||||||
|
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||||
|
|
||||||
|
if isinstance(image, str):
|
||||||
|
image = Image.open(image).convert("RGB")
|
||||||
|
|
||||||
|
multi_modal_data["image"].append(image)
|
||||||
|
else:
|
||||||
|
multi_modal_data = None
|
||||||
|
|
||||||
|
inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
|
||||||
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
|
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
|
||||||
labels.append(
|
labels.append(
|
||||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
|
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
|
||||||
@ -100,6 +123,9 @@ def vllm_infer(
|
|||||||
"disable_log_stats": True,
|
"disable_log_stats": True,
|
||||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||||
}
|
}
|
||||||
|
if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||||
|
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
|
||||||
|
|
||||||
if isinstance(model_args.vllm_config, dict):
|
if isinstance(model_args.vllm_config, dict):
|
||||||
engine_args.update(model_args.vllm_config)
|
engine_args.update(model_args.vllm_config)
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
from ..extras.misc import get_device_count
|
from ..extras.misc import get_device_count
|
||||||
from ..extras.packages import is_pillow_available, is_vllm_available
|
from ..extras.packages import is_pillow_available, is_vllm_available
|
||||||
from ..model import load_config, load_tokenizer
|
from ..model import load_config, load_tokenizer
|
||||||
@ -67,6 +67,7 @@ class VllmEngine(BaseEngine):
|
|||||||
self.processor = tokenizer_module["processor"]
|
self.processor = tokenizer_module["processor"]
|
||||||
self.tokenizer.padding_side = "left"
|
self.tokenizer.padding_side = "left"
|
||||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
||||||
|
self.template.mm_plugin.expand_mm_tokens = False # for vllm generate
|
||||||
self.generating_args = generating_args.to_dict()
|
self.generating_args = generating_args.to_dict()
|
||||||
|
|
||||||
engine_args = {
|
engine_args = {
|
||||||
@ -83,6 +84,9 @@ class VllmEngine(BaseEngine):
|
|||||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||||
"max_lora_rank": model_args.vllm_max_lora_rank,
|
"max_lora_rank": model_args.vllm_max_lora_rank,
|
||||||
}
|
}
|
||||||
|
if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||||
|
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
|
||||||
|
|
||||||
if isinstance(model_args.vllm_config, dict):
|
if isinstance(model_args.vllm_config, dict):
|
||||||
engine_args.update(model_args.vllm_config)
|
engine_args.update(model_args.vllm_config)
|
||||||
|
|
||||||
@ -108,19 +112,21 @@ class VllmEngine(BaseEngine):
|
|||||||
**input_kwargs,
|
**input_kwargs,
|
||||||
) -> AsyncIterator["RequestOutput"]:
|
) -> AsyncIterator["RequestOutput"]:
|
||||||
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
|
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
||||||
if images is not None:
|
if images is not None:
|
||||||
|
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||||
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||||
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||||
|
|
||||||
if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
|
if videos is not None:
|
||||||
image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
|
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
|
||||||
else:
|
if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
|
||||||
image_str = self.template.mm_plugin.image_token or ""
|
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
|
||||||
|
|
||||||
paired_messages = [
|
messages = self.template.mm_plugin.process_messages(
|
||||||
{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
|
messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
|
||||||
for message in messages
|
)
|
||||||
] + [{"role": "assistant", "content": ""}]
|
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||||
system = system or self.generating_args["default_system"]
|
system = system or self.generating_args["default_system"]
|
||||||
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
|
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
|
||||||
prompt_length = len(prompt_ids)
|
prompt_length = len(prompt_ids)
|
||||||
@ -168,7 +174,7 @@ class VllmEngine(BaseEngine):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if images is not None: # add image features
|
if images is not None: # add image features
|
||||||
image_data = []
|
multi_modal_data = {"image": []}
|
||||||
for image in images:
|
for image in images:
|
||||||
if not isinstance(image, (str, ImageObject)):
|
if not isinstance(image, (str, ImageObject)):
|
||||||
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||||
@ -176,9 +182,7 @@ class VllmEngine(BaseEngine):
|
|||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
image = Image.open(image).convert("RGB")
|
image = Image.open(image).convert("RGB")
|
||||||
|
|
||||||
image_data.append(image)
|
multi_modal_data["image"].append(image)
|
||||||
|
|
||||||
multi_modal_data = {"image": image_data}
|
|
||||||
else:
|
else:
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
|
|
||||||
|
@ -62,6 +62,7 @@ class BasePlugin:
|
|||||||
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
|
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
|
||||||
self.image_token = image_token
|
self.image_token = image_token
|
||||||
self.video_token = video_token
|
self.video_token = video_token
|
||||||
|
self.expand_mm_tokens = True
|
||||||
|
|
||||||
def _validate_input(
|
def _validate_input(
|
||||||
self,
|
self,
|
||||||
@ -259,7 +260,7 @@ class LlavaPlugin(BasePlugin):
|
|||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
self._validate_input(images, videos)
|
self._validate_input(images, videos)
|
||||||
num_image_tokens = 0
|
num_image_tokens = 0
|
||||||
image_seqlen = getattr(processor, "image_seqlen")
|
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
|
||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
@ -310,11 +311,13 @@ class LlavaNextPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
image_size = next(image_sizes)
|
if self.expand_mm_tokens:
|
||||||
orig_height, orig_width = image_size
|
orig_height, orig_width = next(image_sizes)
|
||||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||||
image_seqlen -= 1
|
image_seqlen -= 1
|
||||||
|
else:
|
||||||
|
image_seqlen = 1
|
||||||
|
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
@ -359,11 +362,13 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
image_size = next(image_sizes)
|
if self.expand_mm_tokens:
|
||||||
orig_height, orig_width = image_size
|
orig_height, orig_width = next(image_sizes)
|
||||||
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
|
||||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||||
image_seqlen -= 1
|
image_seqlen -= 1
|
||||||
|
else:
|
||||||
|
image_seqlen = 1
|
||||||
|
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
@ -376,6 +381,7 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||||
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
|
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
|
||||||
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
|
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
|
||||||
|
video_seqlen = video_seqlen if self.expand_mm_tokens else 1
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
@ -443,7 +449,7 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
) -> Tuple[List[int], Optional[List[int]]]:
|
) -> Tuple[List[int], Optional[List[int]]]:
|
||||||
self._validate_input(images, videos)
|
self._validate_input(images, videos)
|
||||||
num_images = len(images)
|
num_images = len(images)
|
||||||
image_seqlen = num_images * getattr(processor, "image_seqlen")
|
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
|
||||||
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||||
input_ids = [image_token_id] * image_seqlen + input_ids
|
input_ids = [image_token_id] * image_seqlen + input_ids
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
@ -493,14 +499,18 @@ class PixtralPlugin(BasePlugin):
|
|||||||
if image_input_sizes is None:
|
if image_input_sizes is None:
|
||||||
raise ValueError("Cannot get image input sizes.")
|
raise ValueError("Cannot get image input sizes.")
|
||||||
|
|
||||||
image_size = image_input_sizes[0][num_image_tokens]
|
if self.expand_mm_tokens:
|
||||||
height, width = image_size
|
image_size = image_input_sizes[0][num_image_tokens]
|
||||||
num_height_tokens = height // patch_size
|
height, width = image_size
|
||||||
num_width_tokens = width // patch_size
|
num_height_tokens = height // patch_size
|
||||||
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
num_width_tokens = width // patch_size
|
||||||
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
|
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
|
||||||
replace_tokens[-1] = image_end_token
|
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
|
||||||
replace_str = "".join(replace_tokens)
|
replace_tokens[-1] = image_end_token
|
||||||
|
replace_str = "".join(replace_tokens)
|
||||||
|
else:
|
||||||
|
replace_str = image_token
|
||||||
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
@ -549,10 +559,27 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
|
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
|
||||||
sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
|
results = []
|
||||||
sample_frames = sample_frames // 2 * 2
|
for video in videos:
|
||||||
return sample_frames
|
container = av.open(video, "r")
|
||||||
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
|
total_frames = video_stream.frames
|
||||||
|
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
|
||||||
|
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||||
|
frames: List["ImageObject"] = []
|
||||||
|
container.seek(0)
|
||||||
|
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||||
|
if frame_idx in sample_indices:
|
||||||
|
frames.append(frame.to_image())
|
||||||
|
|
||||||
|
if len(frames) % 2 != 0: # qwen2-vl requires even number of frames
|
||||||
|
frames.append(frames[-1])
|
||||||
|
|
||||||
|
frames = self._regularize_images(frames, **kwargs)
|
||||||
|
results.append(frames)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -577,12 +604,9 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
if num_image_tokens >= len(image_grid_thw):
|
if num_image_tokens >= len(image_grid_thw):
|
||||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER,
|
IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
|
||||||
"<|vision_start|>{}<|vision_end|>".format(
|
|
||||||
self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
|
|
||||||
),
|
|
||||||
1,
|
|
||||||
)
|
)
|
||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
@ -590,12 +614,9 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
if num_video_tokens >= len(video_grid_thw):
|
if num_video_tokens >= len(video_grid_thw):
|
||||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
VIDEO_PLACEHOLDER,
|
VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
|
||||||
"<|vision_start|>{}<|vision_end|>".format(
|
|
||||||
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
|
|
||||||
),
|
|
||||||
1,
|
|
||||||
)
|
)
|
||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
|
|
||||||
@ -640,19 +661,22 @@ class VideoLlavaPlugin(BasePlugin):
|
|||||||
has_images = "pixel_values_images" in mm_inputs
|
has_images = "pixel_values_images" in mm_inputs
|
||||||
has_videos = "pixel_values_videos" in mm_inputs
|
has_videos = "pixel_values_videos" in mm_inputs
|
||||||
if has_images or has_videos:
|
if has_images or has_videos:
|
||||||
if has_images:
|
if self.expand_mm_tokens:
|
||||||
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
|
if has_images:
|
||||||
num_frames = 1
|
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
|
||||||
|
num_frames = 1
|
||||||
|
|
||||||
if has_videos:
|
if has_videos:
|
||||||
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||||
height, width = get_image_size(pixel_values_video[0])
|
height, width = get_image_size(pixel_values_video[0])
|
||||||
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||||
|
|
||||||
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
|
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
|
||||||
video_seqlen = image_seqlen * num_frames
|
video_seqlen = image_seqlen * num_frames
|
||||||
if getattr(processor, "vision_feature_select_strategy") == "default":
|
if getattr(processor, "vision_feature_select_strategy") == "default":
|
||||||
image_seqlen -= 1
|
image_seqlen -= 1
|
||||||
|
else:
|
||||||
|
image_seqlen, video_seqlen = 1, 1
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user