mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 20:30:36 +08:00
[model] add llama4 (#7611)
This commit is contained in:
@@ -466,6 +466,73 @@ class Gemma3Plugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
class Llama4Plugin(BasePlugin):
    r"""Multimodal plugin that expands image placeholders for Llama 4 processors.

    ``process_messages`` rewrites every ``IMAGE_PLACEHOLDER`` occurrence in the
    message contents into the string returned by ``processor._prompt_split_image``
    (when ``self.expand_mm_tokens`` is set), and ``get_mm_inputs`` returns the
    processor outputs without the ``aspect_ratios`` entry.
    """

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Return a deep copy of ``messages`` with image placeholders processed.

        Args:
            messages: Chat messages; each ``"content"`` may contain ``IMAGE_PLACEHOLDER``.
            images: Image inputs, one per placeholder across all messages.
            videos: Video inputs (only forwarded to validation / the processor).
            audios: Audio inputs (only forwarded to validation / the processor).
            processor: Multimodal processor providing ``patch_size``,
                ``downsample_ratio`` and ``_prompt_split_image``.

        Returns:
            A new list of messages; the input list is never mutated.

        Raises:
            ValueError: If the number of placeholders differs from ``len(images)``.
        """
        self._validate_input(processor, images, videos, audios)

        # Bind these up front so a placeholder without any image ends in the
        # ValueError below instead of an UnboundLocalError.
        aspect_ratios = None
        num_patches_per_chunk = 0
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                # Last two dims of the first pixel tensor are taken as (height, width).
                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
                num_patches_per_chunk = int(
                    (image_height // processor.patch_size)
                    * (image_width // processor.patch_size)
                    // processor.downsample_ratio
                )
                aspect_ratios = mm_inputs.pop("aspect_ratios")

        mismatch_error = f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens."
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            if not self.expand_mm_tokens:
                # Content is left untouched, but placeholders must still be counted,
                # otherwise the final check rejects every call that passes images.
                # NOTE(review): other plugins may substitute their image token here;
                # confirm whether the raw placeholder should be kept in this mode.
                num_image_tokens += content.count(IMAGE_PLACEHOLDER)
                continue

            # split() yields one more segment than there are placeholders, so an
            # expansion is inserted before every segment except the first.
            segments = content.split(IMAGE_PLACEHOLDER)
            expanded = [segments[0]]
            for segment in segments[1:]:
                if aspect_ratios is None or num_image_tokens >= len(images):
                    # More placeholders than images: fail with the documented error
                    # rather than an IndexError on ``aspect_ratios``.
                    raise ValueError(mismatch_error)

                expanded.append(
                    processor._prompt_split_image(aspect_ratios[num_image_tokens], num_patches_per_chunk)
                )
                num_image_tokens += 1
                expanded.append(segment)

            message["content"] = "".join(expanded)

        if len(images) != num_image_tokens:
            raise ValueError(mismatch_error)

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build the multimodal inputs, dropping the ``aspect_ratios`` entry.

        ``imglens``, ``vidlens``, ``audlens`` and ``batch_ids`` are accepted for
        interface compatibility and are not used here.
        """
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        # ``aspect_ratios`` is only consumed by ``process_messages``; remove it if present.
        mm_inputs.pop("aspect_ratios", None)
        return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlavaPlugin(BasePlugin):
|
||||
@override
|
||||
@@ -1485,6 +1552,7 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"gemma3": Gemma3Plugin,
|
||||
"llama4": Llama4Plugin,
|
||||
"llava": LlavaPlugin,
|
||||
"llava_next": LlavaNextPlugin,
|
||||
"llava_next_video": LlavaNextVideoPlugin,
|
||||
|
||||
Reference in New Issue
Block a user