mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
video datasets
Former-commit-id: 8cafc7b055a854f483ad1c67f3d487ffd34b5f89
This commit is contained in:
parent
60d770e4b1
commit
9df7a26e6b
@ -38,6 +38,20 @@
|
||||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"mllm_video_demo": {
|
||||
"file_name": "mllm_video_demo.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages",
|
||||
"videos": "videos"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"alpaca_en": {
|
||||
"hf_hub_url": "llamafactory/alpaca_en",
|
||||
"ms_hub_url": "llamafactory/alpaca_en"
|
||||
|
BIN
data/mllm_demo_data/1.mp4
Normal file
BIN
data/mllm_demo_data/1.mp4
Normal file
Binary file not shown.
BIN
data/mllm_demo_data/2.avi
Normal file
BIN
data/mllm_demo_data/2.avi
Normal file
Binary file not shown.
BIN
data/mllm_demo_data/3.mp4
Normal file
BIN
data/mllm_demo_data/3.mp4
Normal file
Binary file not shown.
47
data/mllm_video_demo.json
Normal file
47
data/mllm_video_demo.json
Normal file
@ -0,0 +1,47 @@
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<video>Why is this video funny?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "Because a baby is reading, and he is so cute!",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"videos": [
|
||||
"mllm_demo_data/1.mp4"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<video>What is she doing?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "She is cooking.",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"videos": [
|
||||
"mllm_demo_data/2.avi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<video>What's in the video?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "A baby is playing in the living room.",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"videos": [
|
||||
"mllm_demo_data/3.mp4"
|
||||
]
|
||||
}
|
||||
]
|
@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from vllm import AsyncLLMEngine
|
||||
|
||||
from ..data import Template
|
||||
from ..data.mm_plugin import ImageInput, VideoInput
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
@ -56,7 +56,8 @@ class BaseEngine(ABC):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]: ...
|
||||
|
||||
@ -66,7 +67,8 @@ class BaseEngine(ABC):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]: ...
|
||||
|
||||
|
@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
|
||||
from ..data.mm_plugin import ImageInput, VideoInput
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
@ -56,10 +55,13 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
|
||||
task = asyncio.run_coroutine_threadsafe(
|
||||
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
|
||||
)
|
||||
return task.result()
|
||||
|
||||
async def achat(
|
||||
@ -67,20 +69,22 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
|
||||
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
|
||||
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
||||
@ -93,10 +97,11 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
|
||||
yield new_token
|
||||
|
||||
def get_scores(
|
||||
|
@ -22,7 +22,7 @@ import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
@ -30,11 +30,11 @@ from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from ..data import Template
|
||||
from ..data.mm_plugin import ImageInput, VideoInput
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
@ -78,20 +78,30 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
||||
if image is not None:
|
||||
mm_input_dict.update({"images": [image], "imglens": [1]})
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
|
||||
messages = template.mm_plugin.process_messages(messages, [image], processor)
|
||||
if video is not None:
|
||||
mm_input_dict.update({"videos": [video], "vidlens": [1]})
|
||||
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
|
||||
|
||||
messages = template.mm_plugin.process_messages(
|
||||
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
||||
)
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or generating_args["default_system"]
|
||||
prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
|
||||
if image is not None:
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, [image], tokenizer, processor)
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(
|
||||
prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
|
||||
)
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
@ -154,13 +164,10 @@ class HuggingfaceEngine(BaseEngine):
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(
|
||||
images=[image], imglens=[1], seqlens=[prompt_length], processor=processor
|
||||
)
|
||||
for key, value in mm_inputs.items():
|
||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
||||
for key, value in mm_inputs.items():
|
||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@ -175,11 +182,12 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
@ -210,11 +218,12 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
@ -267,7 +276,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
if not self.can_generate:
|
||||
@ -284,6 +294,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
video,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self.semaphore:
|
||||
@ -295,7 +306,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["Image"] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if not self.can_generate:
|
||||
@ -312,6 +324,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
video,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self.semaphore:
|
||||
|
@ -25,7 +25,7 @@ if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .mm_plugin import ImageInput
|
||||
from .mm_plugin import ImageInput, VideoInput
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
@ -52,6 +52,26 @@ def _convert_images(
|
||||
return images
|
||||
|
||||
|
||||
def _convert_videos(
|
||||
videos: Sequence["VideoInput"],
|
||||
dataset_attr: "DatasetAttr",
|
||||
data_args: "DataArguments",
|
||||
) -> Optional[List["VideoInput"]]:
|
||||
r"""
|
||||
Optionally concatenates video path to dataset dir when loading from local disk.
|
||||
"""
|
||||
if len(videos) == 0:
|
||||
return None
|
||||
|
||||
videos = videos[:]
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for i in range(len(videos)):
|
||||
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
|
||||
videos[i] = os.path.join(data_args.dataset_dir, videos[i])
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
def convert_alpaca(
|
||||
example: Dict[str, Any],
|
||||
dataset_attr: "DatasetAttr",
|
||||
@ -96,12 +116,14 @@ def convert_alpaca(
|
||||
response = []
|
||||
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
|
||||
output = {
|
||||
"_prompt": prompt,
|
||||
"_response": response,
|
||||
"_system": example[dataset_attr.system] if dataset_attr.system else "",
|
||||
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
|
||||
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
|
||||
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
|
||||
}
|
||||
return output
|
||||
|
||||
@ -187,12 +209,14 @@ def convert_sharegpt(
|
||||
prompt, response = [], []
|
||||
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
|
||||
output = {
|
||||
"_prompt": prompt,
|
||||
"_response": response,
|
||||
"_system": system,
|
||||
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
|
||||
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
|
||||
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
|
||||
}
|
||||
return output
|
||||
|
||||
@ -210,6 +234,7 @@ def align_dataset(
|
||||
_system: "..."
|
||||
_tools: "...",
|
||||
_images: [],
|
||||
_videos: [],
|
||||
"""
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
|
||||
|
@ -79,14 +79,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
processor: Optional["ProcessorMixin"] = None
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_imglens, batch_seqlens = [], [], []
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
|
||||
for feature in features:
|
||||
images = feature.pop("images") or [] # avoid NoneType
|
||||
videos = feature.pop("videos") or []
|
||||
batch_images.extend(images)
|
||||
batch_videos.extend(videos)
|
||||
batch_imglens.append(len(images))
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_seqlens.append(len(feature["input_ids"]))
|
||||
|
||||
mm_inputs = self.template.mm_plugin.get_mm_inputs(batch_images, batch_imglens, batch_seqlens, self.processor)
|
||||
mm_inputs = self.template.mm_plugin.get_mm_inputs(
|
||||
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
|
||||
)
|
||||
if "token_type_ids" in mm_inputs:
|
||||
token_type_ids = mm_inputs.pop("token_type_ids")
|
||||
for i, feature in enumerate(features):
|
||||
@ -136,6 +141,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
}
|
||||
concatenated_features.append(target_feature)
|
||||
|
||||
@ -158,12 +164,14 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
}
|
||||
kl_feature = {
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
"images": feature["images"],
|
||||
"videos": feature["videos"],
|
||||
}
|
||||
target_features.append(target_feature)
|
||||
kl_features.append(kl_feature)
|
||||
|
@ -2,11 +2,10 @@ from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from transformers import ProcessorMixin
|
||||
import numpy as np
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available, is_pyav_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
@ -14,8 +13,13 @@ if is_pillow_available():
|
||||
from PIL.Image import Image as ImageObject
|
||||
|
||||
|
||||
if is_pyav_available():
|
||||
import av
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
@ -24,13 +28,14 @@ if TYPE_CHECKING:
|
||||
bytes: Optional[bytes]
|
||||
|
||||
ImageInput = Union[str, EncodedImage, ImageObject]
|
||||
VideoInput = str
|
||||
|
||||
|
||||
def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> List["ImageObject"]:
|
||||
r"""
|
||||
Regularizes images to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
image_resolution = getattr(processor, "image_resolution", 512)
|
||||
image_resolution: int = getattr(processor, "image_resolution", 512)
|
||||
results = []
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
@ -56,7 +61,37 @@ def _regularize_images(images: Sequence["ImageInput"], processor: "ProcessorMixi
|
||||
return results
|
||||
|
||||
|
||||
def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
|
||||
def _regularize_videos(videos: Sequence["VideoInput"], processor: "ProcessorMixin") -> List["NDArray"]:
|
||||
r"""
|
||||
Regularizes videos to avoid error. Including reading, resizing and converting.
|
||||
"""
|
||||
video_fps: float = getattr(processor, "video_fps", 1.0)
|
||||
video_factor: int = getattr(processor, "video_factor", 1)
|
||||
results = []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
total_frames = video_stream.frames
|
||||
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
||||
sample_frames = round(sample_frames / video_factor) * video_factor # for qwen2_vl
|
||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
frames: List["ImageObject"] = []
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
frames = _regularize_images(frames, processor)
|
||||
results.append(frames)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_mm_inputs(
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Processes visual inputs.
|
||||
|
||||
@ -70,13 +105,19 @@ def _get_mm_inputs(images: Sequence["ImageInput"], processor: "ProcessorMixin")
|
||||
It holds num_patches == torch.prod(image_grid_thw)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
input_dict = {"images": None, "videos": None}
|
||||
if len(images) != 0:
|
||||
images = _regularize_images(images, processor)
|
||||
image_inputs = image_processor(images=images, return_tensors="pt")
|
||||
else:
|
||||
image_inputs = {}
|
||||
input_dict["images"] = images
|
||||
|
||||
return image_inputs
|
||||
if len(videos) != 0:
|
||||
videos = _regularize_videos(videos, processor)
|
||||
input_dict["videos"] = videos
|
||||
|
||||
if input_dict["images"] is not None or input_dict["videos"] is not None:
|
||||
return image_processor(**input_dict, return_tensors="pt")
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def _get_paligemma_token_type_ids(
|
||||
@ -97,18 +138,32 @@ def _get_paligemma_token_type_ids(
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
def __init__(self, image_token: str) -> None:
|
||||
def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
|
||||
self.image_token = image_token
|
||||
self.video_token = video_token
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
) -> None:
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError("This model does not support image input.")
|
||||
|
||||
if len(videos) != 0 and self.video_token is None:
|
||||
raise ValueError("This model does not support video input.")
|
||||
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""
|
||||
Pre-processes input messages before tokenization for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
return messages
|
||||
|
||||
def process_token_ids(
|
||||
@ -116,24 +171,29 @@ class BasePlugin:
|
||||
input_ids: List[int],
|
||||
labels: Optional[List[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
r"""
|
||||
Pre-processes token ids after tokenization for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
return input_ids, labels
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
r"""
|
||||
Builds batched multimodal inputs for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos)
|
||||
return {}
|
||||
|
||||
|
||||
@ -142,8 +202,10 @@ class LlavaPlugin(BasePlugin):
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
image_seqlen = getattr(processor, "image_seqlen")
|
||||
messages = deepcopy(messages)
|
||||
@ -163,11 +225,14 @@ class LlavaPlugin(BasePlugin):
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
return _get_mm_inputs(images, processor)
|
||||
self._validate_input(images, videos)
|
||||
return _get_mm_inputs(images, videos, processor)
|
||||
|
||||
|
||||
class PaliGemmaPlugin(BasePlugin):
|
||||
@ -175,8 +240,10 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
@ -197,9 +264,11 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
input_ids: List[int],
|
||||
labels: Optional[List[int]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
self._validate_input(images, videos)
|
||||
num_images = len(images)
|
||||
image_seqlen = num_images * getattr(processor, "image_seqlen")
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
@ -212,11 +281,14 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
mm_inputs = _get_mm_inputs(images, processor)
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = _get_mm_inputs(images, videos, processor)
|
||||
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
||||
return mm_inputs
|
||||
|
||||
@ -226,16 +298,17 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if len(images) != 0:
|
||||
image_grid_thw = _get_mm_inputs(images, processor)["image_grid_thw"]
|
||||
else:
|
||||
image_grid_thw = []
|
||||
mm_inputs = _get_mm_inputs(images, videos, processor)
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
|
||||
num_image_tokens = 0
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
@ -252,21 +325,40 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(video_grid_thw):
|
||||
raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
|
||||
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER,
|
||||
"<|vision_start|>{}<|vision_end|>".format(
|
||||
self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
|
||||
),
|
||||
1,
|
||||
)
|
||||
num_video_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
|
||||
if len(videos) != num_video_tokens:
|
||||
raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
|
||||
|
||||
return messages
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
imglens: Sequence[int],
|
||||
vidlens: Sequence[int],
|
||||
seqlens: Sequence[int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
return _get_mm_inputs(images, processor)
|
||||
self._validate_input(images, videos)
|
||||
return _get_mm_inputs(images, videos, processor)
|
||||
|
||||
|
||||
PLUGINS = {
|
||||
@ -277,9 +369,13 @@ PLUGINS = {
|
||||
}
|
||||
|
||||
|
||||
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
|
||||
def get_mm_plugin(
|
||||
name: str,
|
||||
image_token: Optional[str] = None,
|
||||
video_token: Optional[str] = None,
|
||||
) -> "BasePlugin":
|
||||
plugin_class = PLUGINS.get(name, None)
|
||||
if plugin_class is None:
|
||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||
|
||||
return plugin_class(image_token)
|
||||
return plugin_class(image_token, video_token)
|
||||
|
@ -43,6 +43,7 @@ class DatasetAttr:
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
videos: Optional[str] = None
|
||||
# rlhf columns
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
|
||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
||||
column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..mm_plugin import ImageInput
|
||||
from ..mm_plugin import ImageInput, VideoInput
|
||||
from ..template import Template
|
||||
|
||||
|
||||
@ -38,6 +38,7 @@ def _encode_feedback_example(
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
@ -55,8 +56,8 @@ def _encode_feedback_example(
|
||||
else:
|
||||
kl_messages = prompt + [kl_response[1]]
|
||||
|
||||
messages = template.mm_plugin.process_messages(messages, images, processor)
|
||||
kl_messages = template.mm_plugin.process_messages(kl_messages, images, processor)
|
||||
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
|
||||
kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
|
||||
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
|
||||
|
||||
@ -64,8 +65,8 @@ def _encode_feedback_example(
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
|
||||
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, tokenizer, processor)
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
|
||||
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
|
||||
|
||||
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
@ -103,6 +104,7 @@ def preprocess_feedback_dataset(
|
||||
system=examples["_system"][i],
|
||||
tools=examples["_tools"][i],
|
||||
images=examples["_images"][i] or [],
|
||||
videos=examples["_videos"][i] or [],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
@ -116,6 +118,7 @@ def preprocess_feedback_dataset(
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
model_inputs["images"].append(examples["_images"][i])
|
||||
model_inputs["videos"].append(examples["_videos"][i])
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..mm_plugin import ImageInput
|
||||
from ..mm_plugin import ImageInput, VideoInput
|
||||
from ..template import Template
|
||||
|
||||
|
||||
@ -37,13 +37,14 @@ def _encode_pairwise_example(
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
cutoff_len: int,
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, processor)
|
||||
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, processor)
|
||||
chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
|
||||
rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
|
||||
|
||||
@ -51,7 +52,7 @@ def _encode_pairwise_example(
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, tokenizer, processor)
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
|
||||
# consider the response is more important
|
||||
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
@ -85,6 +86,7 @@ def preprocess_pairwise_dataset(
|
||||
system=examples["_system"][i],
|
||||
tools=examples["_tools"][i],
|
||||
images=examples["_images"][i] or [],
|
||||
videos=examples["_videos"][i] or [],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
@ -97,6 +99,7 @@ def preprocess_pairwise_dataset(
|
||||
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
||||
model_inputs["rejected_labels"].append(rejected_labels)
|
||||
model_inputs["images"].append(examples["_images"][i])
|
||||
model_inputs["videos"].append(examples["_videos"][i])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..mm_plugin import ImageInput
|
||||
from ..mm_plugin import ImageInput, VideoInput
|
||||
from ..template import Template
|
||||
|
||||
|
||||
@ -37,6 +37,7 @@ def _encode_supervised_example(
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
@ -44,8 +45,8 @@ def _encode_supervised_example(
|
||||
train_on_prompt: bool,
|
||||
mask_history: bool,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
messages = template.mm_plugin.process_messages(prompt + response, images, processor)
|
||||
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, tokenizer, processor)
|
||||
messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
|
||||
input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
||||
total_length = len(input_ids) + (1 if template.efficient_eos else 0)
|
||||
if mask_history:
|
||||
@ -107,6 +108,7 @@ def preprocess_supervised_dataset(
|
||||
system=examples["_system"][i],
|
||||
tools=examples["_tools"][i],
|
||||
images=examples["_images"][i] or [],
|
||||
videos=examples["_videos"][i] or [],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
@ -118,6 +120,7 @@ def preprocess_supervised_dataset(
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["images"].append(examples["_images"][i])
|
||||
model_inputs["videos"].append(examples["_videos"][i])
|
||||
|
||||
return model_inputs
|
||||
|
||||
@ -132,11 +135,8 @@ def preprocess_packed_supervised_dataset(
|
||||
# TODO: use `position_ids` to achieve packing
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
if processor is not None:
|
||||
raise NotImplementedError("`packing` have not been implemented for multimodal datasets.")
|
||||
|
||||
valid_num = 0
|
||||
batch_input_ids, batch_labels = [], []
|
||||
batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
|
||||
lengths = []
|
||||
length2indexes = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
@ -150,9 +150,10 @@ def preprocess_packed_supervised_dataset(
|
||||
system=examples["_system"][i],
|
||||
tools=examples["_tools"][i],
|
||||
images=examples["_images"][i] or [],
|
||||
videos=examples["_videos"][i] or [],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=None,
|
||||
processor=processor,
|
||||
cutoff_len=data_args.cutoff_len - 1, # reserved for the padding token
|
||||
train_on_prompt=data_args.train_on_prompt,
|
||||
mask_history=data_args.mask_history,
|
||||
@ -165,16 +166,21 @@ def preprocess_packed_supervised_dataset(
|
||||
length2indexes[length].append(valid_num)
|
||||
batch_input_ids.append(input_ids)
|
||||
batch_labels.append(labels)
|
||||
batch_images.append(examples["_images"][i] or [])
|
||||
batch_videos.append(examples["_videos"][i] or [])
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
|
||||
packed_images, packed_videos = [], []
|
||||
for i, length in enumerate(knapsack):
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
packed_labels += batch_labels[index]
|
||||
packed_images += batch_images[index]
|
||||
packed_videos += batch_videos[index]
|
||||
if data_args.neat_packing:
|
||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||
else:
|
||||
@ -195,7 +201,8 @@ def preprocess_packed_supervised_dataset(
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
model_inputs["attention_mask"].append(packed_attention_masks)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
model_inputs["images"].append(examples["_images"][i])
|
||||
model_inputs["images"].append(packed_images or None)
|
||||
model_inputs["videos"].append(packed_videos or None)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
from ..mm_plugin import ImageInput
|
||||
from ..mm_plugin import ImageInput, VideoInput
|
||||
from ..template import Template
|
||||
|
||||
|
||||
@ -37,6 +37,7 @@ def _encode_unsupervised_example(
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
@ -47,12 +48,12 @@ def _encode_unsupervised_example(
|
||||
else:
|
||||
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
messages = template.mm_plugin.process_messages(messages, images, processor)
|
||||
messages = template.mm_plugin.process_messages(messages, images, videos, processor)
|
||||
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, tokenizer, processor)
|
||||
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
|
||||
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
|
||||
input_ids = input_ids[:source_len]
|
||||
labels = labels[:target_len]
|
||||
@ -79,6 +80,7 @@ def preprocess_unsupervised_dataset(
|
||||
system=examples["_system"][i],
|
||||
tools=examples["_tools"][i],
|
||||
images=examples["_images"][i] or [],
|
||||
videos=examples["_videos"][i] or [],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
@ -88,6 +90,7 @@ def preprocess_unsupervised_dataset(
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["images"].append(examples["_images"][i])
|
||||
model_inputs["videos"].append(examples["_videos"][i])
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
@ -213,7 +212,7 @@ def _register_template(
|
||||
stop_words: Sequence[str] = [],
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
mm_plugin: "BasePlugin" = get_mm_plugin(name="base", image_token=IMAGE_PLACEHOLDER),
|
||||
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
|
||||
) -> None:
|
||||
r"""
|
||||
Registers a chat template.
|
||||
@ -826,7 +825,7 @@ _register_template(
|
||||
default_system="You are a helpful assistant.",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"),
|
||||
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
)
|
||||
|
||||
|
||||
|
@ -95,6 +95,8 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
|
||||
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||
|
||||
VIDEO_PLACEHOLDER = "<video>"
|
||||
|
||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||
|
||||
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
|
||||
|
@ -38,6 +38,10 @@ def _get_package_version(name: str) -> "Version":
|
||||
return version.parse("0.0.0")
|
||||
|
||||
|
||||
def is_pyav_available():
|
||||
return _is_package_available("av")
|
||||
|
||||
|
||||
def is_fastapi_available():
|
||||
return _is_package_available("fastapi")
|
||||
|
||||
|
@ -142,6 +142,10 @@ class ModelArguments:
|
||||
default=512,
|
||||
metadata={"help": "Keeps the height or width of image below this resolution."},
|
||||
)
|
||||
video_fps: float = field(
|
||||
default=2.0,
|
||||
metadata={"help": "The frames to sample per second for video training."},
|
||||
)
|
||||
infer_backend: Literal["huggingface", "vllm"] = field(
|
||||
default="huggingface",
|
||||
metadata={"help": "Backend engine used at inference."},
|
||||
|
@ -100,6 +100,11 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
setattr(processor, "image_seqlen", get_image_seqlen(config))
|
||||
setattr(processor, "image_resolution", model_args.image_resolution)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
if getattr(config, "model_type", None) == "qwen2_vl":
|
||||
setattr(processor, "video_factor", 2)
|
||||
else:
|
||||
setattr(processor, "video_factor", 1)
|
||||
except Exception:
|
||||
processor = None
|
||||
|
||||
|
@ -133,6 +133,7 @@ class WebChatModel(ChatModel):
|
||||
system: str,
|
||||
tools: str,
|
||||
image: Optional[Any],
|
||||
video: Optional[Any],
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
@ -140,7 +141,7 @@ class WebChatModel(ChatModel):
|
||||
chatbot[-1][1] = ""
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
if tools:
|
||||
|
@ -43,8 +43,12 @@ def create_chat_box(
|
||||
system = gr.Textbox(show_label=False)
|
||||
tools = gr.Textbox(show_label=False, lines=3)
|
||||
|
||||
with gr.Column() as image_box:
|
||||
image = gr.Image(sources=["upload"], type="pil")
|
||||
with gr.Column() as mm_box:
|
||||
with gr.Tab("Image"):
|
||||
image = gr.Image(sources=["upload"], type="pil")
|
||||
|
||||
with gr.Tab("Video"):
|
||||
video = gr.Video(sources=["upload"])
|
||||
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
@ -63,7 +67,7 @@ def create_chat_box(
|
||||
[chatbot, messages, query],
|
||||
).then(
|
||||
engine.chatter.stream,
|
||||
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
)
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
|
||||
@ -76,8 +80,9 @@ def create_chat_box(
|
||||
role=role,
|
||||
system=system,
|
||||
tools=tools,
|
||||
image_box=image_box,
|
||||
mm_box=mm_box,
|
||||
image=image,
|
||||
video=video,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
max_new_tokens=max_new_tokens,
|
||||
|
@ -68,7 +68,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
engine.manager.get_elem_by_id("top.model_name").change(
|
||||
lambda model_name: gr.Column(visible=get_visual(model_name)),
|
||||
[engine.manager.get_elem_by_id("top.model_name")],
|
||||
[chat_elems["image_box"]],
|
||||
[chat_elems["mm_box"]],
|
||||
)
|
||||
|
||||
return elem_dict
|
||||
|
@ -59,7 +59,7 @@ class Engine:
|
||||
init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
|
||||
init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
|
||||
init_dict["infer.image_box"] = {"visible": False}
|
||||
init_dict["infer.mm_box"] = {"visible": False}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||
|
@ -1691,6 +1691,20 @@ LOCALES = {
|
||||
"label": "이미지 (선택 사항)",
|
||||
},
|
||||
},
|
||||
"video": {
|
||||
"en": {
|
||||
"label": "Video (optional)",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Видео (по желанию)",
|
||||
},
|
||||
"zh": {
|
||||
"label": "视频(非必填)",
|
||||
},
|
||||
"ko": {
|
||||
"label": "비디오 (선택 사항)",
|
||||
},
|
||||
},
|
||||
"query": {
|
||||
"en": {
|
||||
"placeholder": "Input...",
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -28,6 +28,8 @@ if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
from llamafactory.data.mm_plugin import BasePlugin
|
||||
|
||||
|
||||
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
||||
|
||||
@ -47,10 +49,14 @@ IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
|
||||
|
||||
NO_IMAGES = []
|
||||
|
||||
NO_VIDEOS = []
|
||||
|
||||
IMGLENS = [1]
|
||||
|
||||
NO_IMGLENS = [0]
|
||||
|
||||
NO_VIDLENS = [0]
|
||||
|
||||
INPUT_IDS = [0, 1, 2, 3, 4]
|
||||
|
||||
LABELS = [0, 1, 2, 3, 4]
|
||||
@ -78,92 +84,97 @@ def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenize
|
||||
return tokenizer_module["tokenizer"], tokenizer_module["processor"]
|
||||
|
||||
|
||||
def _check_plugin(
|
||||
plugin: "BasePlugin",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: "ProcessorMixin",
|
||||
expected_mm_messages: Sequence[Dict[str, str]],
|
||||
expected_input_ids: List[int],
|
||||
expected_labels: List[int],
|
||||
expected_mm_inputs: Dict[str, Any],
|
||||
expected_no_mm_inputs: Dict[str, Any],
|
||||
) -> None:
|
||||
# test mm_messages
|
||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, SEQLENS, processor),
|
||||
expected_mm_inputs,
|
||||
)
|
||||
# test text_messages
|
||||
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
|
||||
INPUT_IDS,
|
||||
LABELS,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, SEQLENS, processor),
|
||||
expected_no_mm_inputs,
|
||||
)
|
||||
|
||||
|
||||
def test_base_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
|
||||
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
||||
# test mm_messages
|
||||
assert base_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == MM_MESSAGES
|
||||
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(base_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), {})
|
||||
# test text_messages
|
||||
assert base_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert base_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(base_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = MM_MESSAGES
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = {}
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_llava_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
|
||||
image_seqlen = 576
|
||||
|
||||
mm_inputs = _get_mm_inputs(processor)
|
||||
expected_mm_messages = [
|
||||
check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
|
||||
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
|
||||
# test mm_messages
|
||||
assert llava_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(llava_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert llava_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert llava_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(llava_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||
def test_paligemma_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
|
||||
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
|
||||
image_seqlen = 256
|
||||
|
||||
mm_inputs = _get_mm_inputs(processor)
|
||||
mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
|
||||
expected_mm_messages = [
|
||||
check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
|
||||
]
|
||||
expected_input_ids = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
|
||||
expected_labels = [-100] * image_seqlen + LABELS
|
||||
|
||||
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
|
||||
# test mm_messages
|
||||
assert paligemma_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(paligemma_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert paligemma_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert paligemma_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (
|
||||
INPUT_IDS,
|
||||
LABELS,
|
||||
)
|
||||
_is_close(
|
||||
paligemma_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor),
|
||||
{"token_type_ids": [[1] * 1024]},
|
||||
)
|
||||
check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
|
||||
check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
|
||||
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_qwen2_vl_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
||||
image_seqlen = 4
|
||||
|
||||
mm_inputs = _get_mm_inputs(processor)
|
||||
expected_mm_messages = [
|
||||
check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{
|
||||
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
|
||||
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
||||
# test mm_messages
|
||||
assert qwen2_vl_plugin.process_messages(MM_MESSAGES, IMAGES, processor) == expected_mm_messages
|
||||
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(qwen2_vl_plugin.get_mm_inputs(IMAGES, IMGLENS, SEQLENS, processor), mm_inputs)
|
||||
# test text_messages
|
||||
assert qwen2_vl_plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, processor) == TEXT_MESSAGES
|
||||
assert qwen2_vl_plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, tokenizer, processor) == (INPUT_IDS, LABELS)
|
||||
_is_close(qwen2_vl_plugin.get_mm_inputs(NO_IMAGES, NO_IMGLENS, SEQLENS, processor), {})
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user