diff --git a/README.md b/README.md index 144908c1..e62a9d47 100644 --- a/README.md +++ b/README.md @@ -162,34 +162,36 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Supported Models -| Model | Model size | Template | -| ----------------------------------------------------------------- | -------------------------------- | --------- | -| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | -| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | -| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | -| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | -| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | -| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | -| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | -| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | -| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | -| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | -| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | -| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | -| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | -| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 | -| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | -| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | -| [PaliGemma](https://huggingface.co/google) | 3B | paligemma | -| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | -| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | -| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | -| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | -| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | -| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi | -| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | -| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | +| Model | Model size | Template | +| ----------------------------------------------------------------- | -------------------------------- | ---------------- | +| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | +| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | +| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | +| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | +| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | +| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | +| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | +| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | +| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | +| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | +| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | +| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next | +| 
[LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video | +| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | +| [PaliGemma](https://huggingface.co/google) | 3B | paligemma | +| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | +| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | +| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | +| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | +| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | +| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi | +| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | +| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | > [!NOTE] > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models. diff --git a/README_zh.md b/README_zh.md index 60ac8ee2..b5da9785 100644 --- a/README_zh.md +++ b/README_zh.md @@ -163,34 +163,36 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 模型 -| 模型名 | 模型大小 | Template | -| ----------------------------------------------------------------- | -------------------------------- | --------- | -| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | -| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | -| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | -| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | -| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | -| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | -| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | -| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | -| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | -| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | -| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | -| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | -| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | -| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 | -| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | -| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | -| [PaliGemma](https://huggingface.co/google) | 3B | paligemma | -| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | -| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | -| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | -| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | -| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | -| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | -| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi | -| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | -| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | +| 模型名 | 模型大小 | Template | +| 
----------------------------------------------------------------- | -------------------------------- | ---------------- | +| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 | +| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - | +| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 | +| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere | +| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek | +| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | +| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | +| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | +| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | +| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | +| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | +| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | +| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next | +| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video | +| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 | +| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | +| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | +| [PaliGemma](https://huggingface.co/google) | 3B | paligemma | +| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | +| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | +| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | +| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | +| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | +| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi | +| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl | +| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan | > [!NOTE] > 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。 diff --git a/requirements.txt b/requirements.txt index 5f158eef..e913c58d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ fire packaging pyyaml numpy<2.0.0 +av diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index e22e2760..3684495b 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -4,6 +4,7 @@ from io import BytesIO from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union import numpy as np +from transformers.image_utils import get_image_size, to_numpy_array from typing_extensions import override from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER @@ -157,6 +158,7 @@ class BasePlugin: It holds num_patches == torch.prod(image_grid_thw) """ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor) input_dict = {"images": None} # default key if len(images) != 0: images = self._regularize_images( @@ -174,10 +176,16 @@ class BasePlugin: ) input_dict["videos"] = videos - if input_dict.get("images", None) is not None or input_dict.get("videos", 
None) is not None:
-            return image_processor(**input_dict, return_tensors="pt")
-        else:
-            return {}
+        mm_inputs = {}
+        if image_processor != video_processor:
+            if input_dict.get("images") is not None:
+                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
+            if input_dict.get("videos") is not None:
+                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
+        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
+            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
+
+        return mm_inputs
 
     def process_messages(
         self,
@@ -263,6 +271,125 @@ class LlavaPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)
 
 
+class LlavaNextPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if "image_sizes" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"])
+        if "pixel_values" in mm_inputs:
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+        for message in messages:
+            content = message["content"]
+            while self.image_token in content:
+                image_size = next(image_sizes)
+                orig_height, orig_width = image_size
+                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                if processor.vision_feature_select_strategy == "default":
+                    image_seqlen -= 1
+                num_image_tokens += 1
+                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+
+            message["content"] = content.replace("{{image}}", self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)
+
+
+class LlavaNextVideoPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if "pixel_values" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"])
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+            for message in messages:
+                content = message["content"]
+
+                while self.image_token in content:
+                    image_size = next(image_sizes)
+                    orig_height, orig_width = image_size
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if processor.vision_feature_select_strategy == "default":
+                        image_seqlen -= 1
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+
+                message["content"] = content.replace("{{image}}", self.image_token)
+
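+        # Sketch of the video-token expansion below, assuming HF LLaVA-NeXT-Video
+        # checkpoints: each frame is encoded like an image and then average-pooled
+        # 2x2, keeping a quarter of the patches (e.g. a 336px tower with patch size
+        # 14 gives 24 * 24 = 576 patches per frame, 576 // 4 = 144 after pooling).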
"pixel_values_videos" in mm_inputs: + pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(pixel_values_video[0]) + num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer + + for message in messages: + content = message["content"] + while self.video_token in content: + num_video_tokens += 1 + content = content.replace(self.video_token, "{{video}}", 1) + message["content"] = content.replace("{{video}}", self.video_token * video_seqlen) + + if len(images) != num_image_tokens: + raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) + + if len(videos) != num_video_tokens: + raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) + + return messages + + @override + def get_mm_inputs( + self, + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + imglens: Sequence[int], + vidlens: Sequence[int], + seqlens: Sequence[int], + processor: Optional["ProcessorMixin"], + ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + self._validate_input(images, videos) + return self._get_mm_inputs(images, videos, processor) + + class PaliGemmaPlugin(BasePlugin): @override def process_messages( @@ -417,11 +541,77 @@ class Qwen2vlPlugin(BasePlugin): return self._get_mm_inputs(images, videos, processor) +class VideoLlavaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: Sequence[Dict[str, str]], + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + processor: Optional["ProcessorMixin"], + ) -> List[Dict[str, str]]: + self._validate_input(images, videos) + num_image_tokens = 0 + num_video_tokens = 0 + messages = deepcopy(messages) + mm_inputs = self._get_mm_inputs(images, videos, processor) + num_frames = 0 + exist_images = "pixel_values_images" in mm_inputs + exist_videos = "pixel_values_videos" in mm_inputs + if exist_videos or exist_images: + if exist_images: + height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0])) + num_frames = 1 + if exist_videos: + pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(pixel_values_video[0]) + num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1 + video_seqlen = image_seqlen * num_frames + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + for message in messages: + content = message["content"] + while self.image_token in content: + num_image_tokens += 1 + content = content.replace(self.image_token, "{{image}}", 1) + while self.video_token in content: + num_video_tokens += 1 + content = content.replace(self.video_token, "{{video}}", 1) + + content = content.replace("{{image}}", self.image_token * image_seqlen) + message["content"] = content.replace("{{video}}", self.video_token * video_seqlen) + + if len(images) != num_image_tokens: + raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token)) + + if len(videos) != num_video_tokens: + raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token)) + + return messages + + @override + def get_mm_inputs( + self, + 
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+            video_seqlen = image_seqlen * num_frames
+            if processor.vision_feature_select_strategy == "default":
+                image_seqlen -= 1
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+
+                content = content.replace("{{image}}", self.image_token * image_seqlen)
+                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))
+
+        if len(videos) != num_video_tokens:
+            raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)
+
+
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
+    "llava_next": LlavaNextPlugin,
+    "llava_next_video": LlavaNextVideoPlugin,
     "paligemma": PaliGemmaPlugin,
     "qwen2_vl": Qwen2vlPlugin,
+    "video_llava": VideoLlavaPlugin,
 }
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index c7d47ebc..7a10a0e3 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -760,6 +760,107 @@ _register_template(
 )
 
 
+_register_template(
+    name="llava_next",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_llama3",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_mistral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_qwen",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_yi",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_video",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="