mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-15 08:08:09 +08:00
Support new features of MiniCPM-V (#6626)
* fix template name * tiny fix * support minicpm-o-2.6 Former-commit-id: 53034a61c7654358f46916cbc370910fb2aeff3b
This commit is contained in:
parent
2a05941b14
commit
ae32c148d1
@ -209,7 +209,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||||
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||||
| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_v |
|
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||||
|
@ -210,7 +210,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||||
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
|
||||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
|
||||||
| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_v |
|
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
|
||||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||||
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
|
||||||
|
10
setup.py
10
setup.py
@ -59,6 +59,16 @@ extra_require = {
|
|||||||
"badam": ["badam>=1.2.1"],
|
"badam": ["badam>=1.2.1"],
|
||||||
"adam-mini": ["adam-mini"],
|
"adam-mini": ["adam-mini"],
|
||||||
"qwen": ["transformers_stream_generator"],
|
"qwen": ["transformers_stream_generator"],
|
||||||
|
"minicpm_v": [
|
||||||
|
"soundfile",
|
||||||
|
"torchvision",
|
||||||
|
"torchaudio",
|
||||||
|
"vector_quantize_pytorch",
|
||||||
|
"vocos",
|
||||||
|
"msgpack",
|
||||||
|
"referencing",
|
||||||
|
"jsonschema_specifications",
|
||||||
|
],
|
||||||
"modelscope": ["modelscope"],
|
"modelscope": ["modelscope"],
|
||||||
"openmind": ["openmind"],
|
"openmind": ["openmind"],
|
||||||
"swanlab": ["swanlab"],
|
"swanlab": ["swanlab"],
|
||||||
|
@ -153,9 +153,8 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
features = features.data # use default_collate() instead of BatchEncoding.to()
|
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||||
|
|
||||||
if "image_bound" in features: # for minicpmv inputs
|
if "image_bound" in features: # for minicpmv inputs
|
||||||
features["position_ids"] = (
|
bsz, seq_length = features["input_ids"].shape
|
||||||
torch.arange(features["input_ids"].size(1)).long().unsqueeze(0).expand_as(features["input_ids"])
|
features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
|
||||||
)
|
|
||||||
return {"data": features, "labels": features["labels"]}
|
return {"data": features, "labels": features["labels"]}
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
@ -254,156 +254,6 @@ class BasePlugin:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class CpmVPlugin(BasePlugin):
|
|
||||||
@override
|
|
||||||
def process_messages(
|
|
||||||
self,
|
|
||||||
messages: Sequence[Dict[str, str]],
|
|
||||||
images: Sequence["ImageInput"],
|
|
||||||
videos: Sequence["VideoInput"],
|
|
||||||
processor: Optional["ProcessorMixin"],
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
self._validate_input(images, videos)
|
|
||||||
num_image_tokens = 0
|
|
||||||
num_video_tokens = 0
|
|
||||||
messages = deepcopy(messages)
|
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
|
||||||
mm_inputs = {}
|
|
||||||
if len(images) != 0 and len(videos) != 0:
|
|
||||||
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
|
||||||
|
|
||||||
if len(videos) != 0:
|
|
||||||
max_slice_nums = 2
|
|
||||||
use_image_id = False
|
|
||||||
mm_inputs = self._get_mm_inputs([], videos, processor)
|
|
||||||
else:
|
|
||||||
max_slice_nums = image_processor.max_slice_nums
|
|
||||||
use_image_id = image_processor.use_image_id
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
content = message["content"]
|
|
||||||
while IMAGE_PLACEHOLDER in content:
|
|
||||||
num_image_tokens += 1
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
|
||||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
|
||||||
num_video_tokens += 1
|
|
||||||
|
|
||||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
|
||||||
|
|
||||||
if num_image_tokens > 0:
|
|
||||||
mm_inputs = self._get_mm_inputs(images, [], processor)
|
|
||||||
|
|
||||||
if mm_inputs:
|
|
||||||
pattern = "(<image>./</image>)"
|
|
||||||
image_sizes = mm_inputs["image_sizes"]
|
|
||||||
|
|
||||||
for index, message in enumerate(messages):
|
|
||||||
text = message["content"]
|
|
||||||
image_tags = re.findall(pattern, text)
|
|
||||||
text_chunks = text.split(pattern)
|
|
||||||
final_text = ""
|
|
||||||
for i in range(len(image_tags)):
|
|
||||||
final_text = (
|
|
||||||
final_text
|
|
||||||
+ text_chunks[i]
|
|
||||||
+ image_processor.get_slice_image_placeholder(
|
|
||||||
image_sizes[0][i], i, max_slice_nums, use_image_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
final_text += text_chunks[-1]
|
|
||||||
messages[index]["content"] = final_text
|
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
if len(videos) != num_video_tokens:
|
|
||||||
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
@override
|
|
||||||
def _get_mm_inputs(
|
|
||||||
self,
|
|
||||||
images: Sequence["ImageInput"],
|
|
||||||
videos: Sequence["VideoInput"],
|
|
||||||
processor: "ProcessorMixin",
|
|
||||||
**kwargs,
|
|
||||||
) -> Dict[str, "torch.Tensor"]:
|
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
|
||||||
|
|
||||||
mm_inputs = {}
|
|
||||||
if len(images) != 0:
|
|
||||||
images = self._regularize_images(
|
|
||||||
images,
|
|
||||||
image_resolution=getattr(processor, "image_resolution", 512 * 512),
|
|
||||||
)
|
|
||||||
if "valid_image_nums_ls" in kwargs:
|
|
||||||
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
|
|
||||||
new_images = []
|
|
||||||
idx = 0
|
|
||||||
for valid_image_nums in valid_image_nums_ls:
|
|
||||||
new_images.append(images[idx : idx + valid_image_nums])
|
|
||||||
idx += valid_image_nums
|
|
||||||
|
|
||||||
images = new_images
|
|
||||||
|
|
||||||
image_inputs = image_processor(
|
|
||||||
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
|
|
||||||
)
|
|
||||||
mm_inputs.update(image_inputs)
|
|
||||||
|
|
||||||
if len(videos) != 0:
|
|
||||||
videos = self._regularize_videos(
|
|
||||||
videos,
|
|
||||||
image_resolution=getattr(processor, "video_resolution", 128 * 128),
|
|
||||||
video_fps=getattr(processor, "video_fps", 2.0),
|
|
||||||
video_maxlen=getattr(processor, "video_maxlen", 64),
|
|
||||||
)
|
|
||||||
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
|
|
||||||
mm_inputs.update(video_inputs)
|
|
||||||
|
|
||||||
return mm_inputs
|
|
||||||
|
|
||||||
@override
|
|
||||||
def get_mm_inputs(
|
|
||||||
self,
|
|
||||||
images: Sequence["ImageInput"],
|
|
||||||
videos: Sequence["VideoInput"],
|
|
||||||
imglens: Sequence[int],
|
|
||||||
vidlens: Sequence[int],
|
|
||||||
batch_ids: Sequence[List[int]],
|
|
||||||
processor: Optional["ProcessorMixin"],
|
|
||||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
|
||||||
self._validate_input(images, videos)
|
|
||||||
image_bounds_list = []
|
|
||||||
valid_image_nums_ls = []
|
|
||||||
for input_ids in batch_ids:
|
|
||||||
input_ids_ = torch.tensor(input_ids)
|
|
||||||
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
|
||||||
input_ids_ == processor.tokenizer.slice_start_id
|
|
||||||
)
|
|
||||||
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
|
||||||
image_start_tokens = torch.where(start_cond)[0]
|
|
||||||
image_start_tokens += 1
|
|
||||||
image_end_tokens = torch.where(end_cond)[0]
|
|
||||||
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
|
||||||
valid_image_nums_ls.append(valid_image_nums)
|
|
||||||
image_bounds = torch.hstack(
|
|
||||||
[
|
|
||||||
image_start_tokens[:valid_image_nums].unsqueeze(-1),
|
|
||||||
image_end_tokens[:valid_image_nums].unsqueeze(-1),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
image_bounds_list.append(image_bounds)
|
|
||||||
|
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
|
|
||||||
mm_inputs.update({"image_bound": image_bounds_list})
|
|
||||||
return mm_inputs
|
|
||||||
|
|
||||||
|
|
||||||
class LlavaPlugin(BasePlugin):
|
class LlavaPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -567,6 +417,156 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
return self._get_mm_inputs(images, videos, processor)
|
return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class MiniCPMVPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
num_image_tokens = 0
|
||||||
|
num_video_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
mm_inputs = {}
|
||||||
|
if len(images) != 0 and len(videos) != 0:
|
||||||
|
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
||||||
|
|
||||||
|
if len(videos) != 0:
|
||||||
|
max_slice_nums = 2
|
||||||
|
use_image_id = False
|
||||||
|
mm_inputs = self._get_mm_inputs([], videos, processor)
|
||||||
|
else:
|
||||||
|
max_slice_nums = image_processor.max_slice_nums
|
||||||
|
use_image_id = image_processor.use_image_id
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||||
|
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||||
|
num_video_tokens += 1
|
||||||
|
|
||||||
|
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||||
|
|
||||||
|
if num_image_tokens > 0:
|
||||||
|
mm_inputs = self._get_mm_inputs(images, [], processor)
|
||||||
|
|
||||||
|
if mm_inputs:
|
||||||
|
pattern = "(<image>./</image>)"
|
||||||
|
image_sizes = mm_inputs["image_sizes"]
|
||||||
|
|
||||||
|
for index, message in enumerate(messages):
|
||||||
|
text = message["content"]
|
||||||
|
image_tags = re.findall(pattern, text)
|
||||||
|
text_chunks = text.split(pattern)
|
||||||
|
final_text = ""
|
||||||
|
for i in range(len(image_tags)):
|
||||||
|
final_text = (
|
||||||
|
final_text
|
||||||
|
+ text_chunks[i]
|
||||||
|
+ image_processor.get_slice_image_placeholder(
|
||||||
|
image_sizes[0][i], i, max_slice_nums, use_image_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
final_text += text_chunks[-1]
|
||||||
|
messages[index]["content"] = final_text
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
if len(videos) != num_video_tokens:
|
||||||
|
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: "ProcessorMixin",
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, "torch.Tensor"]:
|
||||||
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
mm_inputs = {}
|
||||||
|
if len(images) != 0:
|
||||||
|
images = self._regularize_images(
|
||||||
|
images,
|
||||||
|
image_resolution=getattr(processor, "image_resolution", 512 * 512),
|
||||||
|
)
|
||||||
|
if "valid_image_nums_ls" in kwargs:
|
||||||
|
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
|
||||||
|
new_images = []
|
||||||
|
idx = 0
|
||||||
|
for valid_image_nums in valid_image_nums_ls:
|
||||||
|
new_images.append(images[idx : idx + valid_image_nums])
|
||||||
|
idx += valid_image_nums
|
||||||
|
|
||||||
|
images = new_images
|
||||||
|
|
||||||
|
image_inputs = image_processor(
|
||||||
|
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
|
||||||
|
)
|
||||||
|
mm_inputs.update(image_inputs)
|
||||||
|
|
||||||
|
if len(videos) != 0:
|
||||||
|
videos = self._regularize_videos(
|
||||||
|
videos,
|
||||||
|
image_resolution=getattr(processor, "video_resolution", 128 * 128),
|
||||||
|
video_fps=getattr(processor, "video_fps", 2.0),
|
||||||
|
video_maxlen=getattr(processor, "video_maxlen", 64),
|
||||||
|
)
|
||||||
|
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
|
||||||
|
mm_inputs.update(video_inputs)
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
batch_ids: Sequence[List[int]],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
image_bounds_list = []
|
||||||
|
valid_image_nums_ls = []
|
||||||
|
for input_ids in batch_ids:
|
||||||
|
input_ids_ = torch.tensor(input_ids)
|
||||||
|
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
||||||
|
input_ids_ == processor.tokenizer.slice_start_id
|
||||||
|
)
|
||||||
|
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
||||||
|
image_start_tokens = torch.where(start_cond)[0]
|
||||||
|
image_start_tokens += 1
|
||||||
|
image_end_tokens = torch.where(end_cond)[0]
|
||||||
|
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
||||||
|
valid_image_nums_ls.append(valid_image_nums)
|
||||||
|
image_bounds = torch.hstack(
|
||||||
|
[
|
||||||
|
image_start_tokens[:valid_image_nums].unsqueeze(-1),
|
||||||
|
image_end_tokens[:valid_image_nums].unsqueeze(-1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image_bounds_list.append(image_bounds)
|
||||||
|
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
|
||||||
|
mm_inputs.update({"image_bound": image_bounds_list})
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
class PaliGemmaPlugin(BasePlugin):
|
class PaliGemmaPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -945,10 +945,10 @@ class MllamaPlugin(BasePlugin):
|
|||||||
|
|
||||||
PLUGINS = {
|
PLUGINS = {
|
||||||
"base": BasePlugin,
|
"base": BasePlugin,
|
||||||
"cpm_v": CpmVPlugin,
|
|
||||||
"llava": LlavaPlugin,
|
"llava": LlavaPlugin,
|
||||||
"llava_next": LlavaNextPlugin,
|
"llava_next": LlavaNextPlugin,
|
||||||
"llava_next_video": LlavaNextVideoPlugin,
|
"llava_next_video": LlavaNextVideoPlugin,
|
||||||
|
"minicpm_v": MiniCPMVPlugin,
|
||||||
"paligemma": PaliGemmaPlugin,
|
"paligemma": PaliGemmaPlugin,
|
||||||
"pixtral": PixtralPlugin,
|
"pixtral": PixtralPlugin,
|
||||||
"qwen2_vl": Qwen2vlPlugin,
|
"qwen2_vl": Qwen2vlPlugin,
|
||||||
|
@ -576,17 +576,6 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
|
||||||
_register_template(
|
|
||||||
name="cpm_v",
|
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
|
||||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
|
||||||
stop_words=["<|im_end|>"],
|
|
||||||
mm_plugin=get_mm_plugin(name="cpm_v", image_token="<image>", video_token="<video>"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
_register_template(
|
||||||
name="dbrx",
|
name="dbrx",
|
||||||
@ -961,6 +950,17 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from chatml template
|
||||||
|
_register_template(
|
||||||
|
name="minicpm_v",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="mistral",
|
name="mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||||
|
@ -1163,6 +1163,17 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"MiniCPM-o-2_6-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
|
||||||
|
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="minicpm_v",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"MiniCPM-V-2_6-Chat": {
|
"MiniCPM-V-2_6-Chat": {
|
||||||
@ -1170,7 +1181,7 @@ register_model_group(
|
|||||||
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
|
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
template="cpm_v",
|
template="minicpm_v",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -253,6 +253,7 @@ _register_composite_model(
|
|||||||
_register_composite_model(
|
_register_composite_model(
|
||||||
model_type="minicpmv",
|
model_type="minicpmv",
|
||||||
vision_model_keys=["vpm", "apm", "resampler", "tts"],
|
vision_model_keys=["vpm", "apm", "resampler", "tts"],
|
||||||
|
language_model_keys=["llm"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user